Upload 26 files
- api.py +422 -509
- background_tasks.py +37 -110
- codegen_api.py +8 -17
- coder.py +139 -0
- image_to_3d_api.py +13 -25
- imagegen_api.py +7 -23
- main.py +9 -56
- model_loader.py +229 -725
- models.py +53 -42
- musicgen_api.py +9 -28
- sadtalker_api.py +16 -183
- sadtalker_utils.py +209 -820
- sentiment_api.py +7 -20
- stt_api.py +11 -27
- summarization_api.py +8 -21
- text_generation.py +93 -194
- text_to_video_api.py +13 -25
- tokenxxx.py +44 -114
- translation_api.py +7 -16
- tts_api.py +9 -20
- xtts_api.py +21 -0
- xxx.py +43 -114
api.py
CHANGED
@@ -1,509 +1,422 @@
-from main import *
-(old api.py lines 2-422 — the remaining module imports and the inline HTML/CSS template — did not survive the page extraction and are omitted)
-                    clean_word = word.strip()
-                    if clean_word:
-                        yield "data: " + clean_word + "\n\n"
-            yield "data: <END_STREAM>\n\n"
-            break
-    return Response(event_stream(), mimetype="text/event-stream")
-
-@app.route("/api/v1/generate", methods=["POST"])
-def generate():
-    data = request.get_json()
-    text = data.get("text", "")
-    temp = float(data.get("temp", 0.7))
-    top_k = int(data.get("top_k", 40))
-    top_p = float(data.get("top_p", 0.0))
-    reppenalty = float(data.get("reppenalty", 1.2))
-    response_queue = queue.Queue()
-    reasoning_queue.put({
-        'text_input': text,
-        'temperature': temp,
-        'top_k': top_k,
-        'top_p': top_p,
-        'repetition_penalty': reppenalty,
-        'response_queue': response_queue
-    })
-    output = response_queue.get()
-    if "error" in output:
-        return jsonify({"error": output["error"]}), 500
-    result_text = output.get("text", "").strip()
-    return jsonify({"response": result_text})
-
-@app.route("/api/v1/feedback", methods=["POST"])
-def feedback():
-    data = request.get_json()
-    feedback_text = data.get("feedback_text")
-    correct_category = data.get("correct_category")
-    if feedback_text and correct_category:
-        feedback_queue.put((feedback_text, correct_category))
-        return jsonify({"status": "feedback received"})
-    return jsonify({"status": "feedback failed"}), 400
-
-@app.route("/api/v1/tts", methods=["POST"])
-def tts_api():
-    return tts_route()
-
-@app.route("/api/v1/stt", methods=["POST"])
-def stt_api():
-    return stt_route()
-
-@app.route("/api/v1/sentiment", methods=["POST"])
-def sentiment_api():
-    return sentiment_route()
-
-@app.route("/api/v1/imagegen", methods=["POST"])
-def imagegen_api():
-    return imagegen_route()
-
-@app.route("/api/v1/musicgen", methods=["POST"])
-def musicgen_api():
-    return musicgen_route()
-
-@app.route("/api/v1/translation", methods=["POST"])
-def translation_api():
-    return translation_route()
-
-@app.route("/api/v1/codegen", methods=["POST"])
-def codegen_api():
-    return codegen_route()
-
-@app.route("/api/v1/text_to_video", methods=["POST"])
-def text_to_video_api():
-    return text_to_video_route()
-
-@app.route("/api/v1/summarization", methods=["POST"])
-def summarization_api():
-    return summarization_route()
-
-@app.route("/api/v1/image_to_3d", methods=["POST"])
-def image_to_3d_api():
-    return image_to_3d_route()
-
-@app.route("/api/v1/sadtalker", methods=["POST"])
-def sadtalker():
-    from sadtalker_api import router as sadtalker_router
-    return sadtalker_router.create_video()
-
-if __name__ == "__main__":
-    app.run(host="0.0.0.0", port=7860)
+from main import *
+from tts_api import tts_api as tts_module_api
+from stt_api import stt_api as stt_module_api
+from sentiment_api import sentiment_api as sentiment_module_api
+from imagegen_api import imagegen_api as imagegen_module_api
+from musicgen_api import musicgen_api as musicgen_module_api
+from translation_api import translation_api as translation_module_api
+from codegen_api import codegen_api as codegen_module_api
+from text_to_video_api import text_to_video_api as text_to_video_module_api
+from summarization_api import summarization_api as summarization_module_api
+from image_to_3d_api import image_to_3d_api as image_to_3d_module_api
+from xtts_api import xtts_api as xtts_module_api
+from flask import Flask, request, jsonify, Response, send_file, stream_with_context
+from flask_cors import CORS
+import io
+import queue
+import base64
+import gradio as gr
+
+app = Flask(__name__)
+CORS(app)
+
+html_code = """<!DOCTYPE html>
+<html lang="en">
+<head>
+    <meta charset="UTF-8">
+    <meta name="viewport" content="width=device-width, initial-scale=1.0">
+    <title>AI Text Generation</title>
+    <link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/animate.css/4.1.1/animate.min.css"/>
+    <link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.0.0/css/all.min.css" integrity="sha512-9usAa10IRO0HhonpyAIVpjrylPvoDwiPUiKdWk5t3PyolY1cOd4DSE0Ga+ri4AuTroPR5aQvXU9xC6qOPnzFeg==" crossorigin="anonymous" referrerpolicy="no-referrer" />
+    <script src="https://cdn.jsdelivr.net/npm/marked/marked.min.js"></script>
+    <style>
+        body {
+            font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
+            background: #f0f0f0;
+            color: #333;
+            margin: 0;
+            padding: 0;
+            display: flex;
+            flex-direction: column;
+            align-items: center;
+            min-height: 100vh;
+        }
+        .container {
+            width: 95%;
+            max-width: 900px;
+            padding: 20px;
+            background-color: #fff;
+            box-shadow: 0 0 10px rgba(0, 0, 0, 0.1);
+            border-radius: 8px;
+            margin-top: 20px;
+            margin-bottom: 20px;
+            display: flex;
+            flex-direction: column;
+        }
+        .header {
+            text-align: center;
+            margin-bottom: 20px;
+        }
+        .header h1 {
+            font-size: 2em;
+            color: #333;
+        }
+        .form-group {
+            margin-bottom: 15px;
+        }
+        .form-group textarea {
+            width: 100%;
+            padding: 10px;
+            border: 1px solid #ccc;
+            border-radius: 5px;
+            font-size: 16px;
+            box-sizing: border-box;
+            resize: vertical;
+        }
+        button {
+            padding: 10px 15px;
+            border: none;
+            border-radius: 5px;
+            background-color: #007bff;
+            color: white;
+            font-size: 18px;
+            cursor: pointer;
+            transition: background-color 0.3s ease;
+        }
+        button:hover {
+            background-color: #0056b3;
+        }
+        #output {
+            margin-top: 20px;
+            padding: 15px;
+            border: 1px solid #ddd;
+            border-radius: 5px;
+            background-color: #f9f9f9;
+            white-space: pre-wrap;
+            word-break: break-word;
+            overflow-y: auto;
+            max-height: 100vh;
+        }
+        #output strong {
+            font-weight: bold;
+        }
+        .animated-text {
+            position: fixed;
+            top: 20px;
+            left: 20px;
+            font-size: 1.5em;
+            color: rgba(0, 0, 0, 0.1);
+            pointer-events: none;
+            z-index: -1;
+        }
+        @media (max-width: 768px) {
+            .container {
+                width: 98%;
+                margin-top: 10px;
+                margin-bottom: 10px;
+                padding: 15px;
+            }
+            .header h1 {
+                font-size: 1.8em;
+            }
+            .form-group textarea, .form-group input[type="text"] {
+                font-size: 14px;
+                padding: 8px;
+            }
+            button {
+                font-size: 16px;
+                padding: 8px 12px;
+            }
+            #output {
+                font-size: 14px;
+                padding: 10px;
+                margin-top: 15px;
+            }
+        }
+    </style>
+</head>
+<body>
+    <div class="animated-text animate__animated animate__fadeIn animate__infinite infinite">AI POWERED</div>
+    <div class="container">
+        <div class="header animate__animated animate__fadeInDown">
+        </div>
+        <div class="form-group animate__animated animate__fadeInLeft">
+            <textarea id="text" rows="5" placeholder="Enter text"></textarea>
+        </div>
+        <button onclick="generateText()" class="animate__animated animate__fadeInUp">Generate Reasoning</button>
+        <div id="output" class="animate__animated">
+            <strong>Response:</strong><br>
+            <span id="generatedText"></span>
+        </div>
+    </div>
+    <script>
+        let eventSource = null;
+        let accumulatedText = "";
+        let lastResponse = "";
+        async function generateText() {
+            const inputText = document.getElementById("text").value;
+            document.getElementById("generatedText").innerText = "";
+            accumulatedText = "";
+            if (eventSource) {
+                eventSource.close();
+            }
+            const temp = 0.7;
+            const top_k_val = 40;
+            const top_p_val = 0.0;
+            const repetition_penalty_val = 1.2;
+            const requestData = {
+                text: inputText,
+                temp: temp,
+                top_k: top_k_val,
+                top_p: top_p_val,
+                reppenalty: repetition_penalty_val
+            };
+            const params = new URLSearchParams(requestData).toString();
+            eventSource = new EventSource('/api/v1/generate_stream?' + params);
+            eventSource.onmessage = function(event) {
+                if (event.data === "<END_STREAM>") {
+                    eventSource.close();
+                    const currentResponse = accumulatedText.replace("<|endoftext|>", "").replace(/\s+(?=[.,,。])/g, '').trim();
+                    if (currentResponse === lastResponse.trim()) {
+                        accumulatedText = "**Response is repetitive. Please try again or rephrase your query.**";
+                    } else {
+                        lastResponse = currentResponse;
+                    }
+                    document.getElementById("generatedText").innerHTML = marked.parse(accumulatedText);
+                    return;
+                }
+                accumulatedText += event.data;
+                let partialText = accumulatedText.replace("<|endoftext|>", "").replace(/\s+(?=[.,,。])/g, '').trim();
+                document.getElementById("generatedText").innerHTML = marked.parse(partialText);
+            };
+            eventSource.onerror = function(error) {
+                console.error("SSE error", error);
+                eventSource.close();
+            };
+            const outputDiv = document.getElementById("output");
+            outputDiv.classList.add("show");
+        }
+        function base64ToBlob(base64Data, contentType) {
+            contentType = contentType || '';
+            const sliceSize = 1024;
+            const byteCharacters = atob(base64Data);
+            const bytesLength = byteCharacters.length;
+            const slicesCount = Math.ceil(bytesLength / sliceSize);
+            const byteArrays = new Array(slicesCount);
+            for (let sliceIndex = 0; sliceIndex < slicesCount; ++sliceIndex) {
+                const begin = sliceIndex * sliceSize;
+                const end = Math.min(begin + sliceSize, bytesLength);
+                const bytes = new Array(end - begin);
+                for (let offset = begin, i = 0; offset < end; ++i, ++offset) {
+                    bytes[i] = byteCharacters[offset].charCodeAt(0);
+                }
+                byteArrays[sliceIndex] = new Uint8Array(bytes);
+            }
+            return new Blob(byteArrays, { type: contentType });
+        }
+    </script>
+</body>
+</html>
+"""
+feedback_queue = queue.Queue()
+
+
+@app.route("/")
+def index():
+    return html_code
+
+@app.route("/api/v1/generate_stream", methods=["GET"])
+def generate_stream():
+    text = request.args.get("text", "")
+    temp = float(request.args.get("temp", 0.7))
+    top_k = int(request.args.get("top_k", 40))
+    top_p = float(request.args.get("top_p", 0.0))
+    reppenalty = float(request.args.get("reppenalty", 1.2))
+    response_queue = queue.Queue()
+    reasoning_queue.put({
+        'text_input': text,
+        'temperature': temp,
+        'top_k': top_k,
+        'top_p': top_p,
+        'repetition_penalty': reppenalty,
+        'response_queue': response_queue
+    })
+    @stream_with_context
+    def event_stream():
+        while True:
+            output = response_queue.get()
+            if "error" in output:
+                yield "data: <ERROR>\n\n"
+                break
+            text_chunk = output.get("text")
+            if text_chunk:
+                for word in text_chunk.split(' '):
+                    clean_word = word.strip()
+                    if clean_word:
+                        yield "data: " + clean_word + "\n\n"
+            yield "data: <END_STREAM>\n\n"
+            break
+    return Response(event_stream(), mimetype="text/event-stream")
+
+@app.route("/api/v1/generate", methods=["POST"])
+def generate():
+    data = request.get_json()
+    text = data.get("text", "")
+    temp = float(data.get("temp", 0.7))
+    top_k = int(data.get("top_k", 40))
+    top_p = float(data.get("top_p", 0.0))
+    reppenalty = float(data.get("reppenalty", 1.2))
+    response_queue = queue.Queue()
+    reasoning_queue.put({
+        'text_input': text,
+        'temperature': temp,
+        'top_k': top_k,
+        'top_p': top_p,
+        'repetition_penalty': reppenalty,
+        'response_queue': response_queue
+    })
+    output = response_queue.get()
+    if "error" in output:
+        return jsonify({"error": output["error"]}), 500
+    result_text = output.get("text", "").strip()
+    return jsonify({"response": result_text})
+
+@app.route("/api/v1/feedback", methods=["POST"])
+def feedback():
+    data = request.get_json()
+    feedback_text = data.get("feedback_text")
+    correct_category = data.get("correct_category")
+    if feedback_text and correct_category:
+        feedback_queue.put((feedback_text, correct_category))
+        return jsonify({"status": "feedback received"})
+    return jsonify({"status": "feedback failed"}), 400
+
+@app.route("/api/v1/tts", methods=["POST"])
+def tts_api():
+    return tts_module_api()
+
+@app.route("/api/v1/stt", methods=["POST"])
+def stt_api():
+    return stt_module_api()
+
+@app.route("/api/v1/sentiment", methods=["POST"])
+def sentiment_api():
+    return sentiment_module_api()
+
+@app.route("/api/v1/imagegen", methods=["POST"])
+def imagegen_api():
+    return imagegen_module_api()
+
+@app.route("/api/v1/musicgen", methods=["POST"])
+def musicgen_api():
+    return musicgen_module_api()
+
+@app.route("/api/v1/translation", methods=["POST"])
+def translation_api():
+    return translation_module_api()
+
+@app.route("/api/v1/codegen", methods=["POST"])
+def codegen_api():
+    return codegen_module_api()
+
+@app.route("/api/v1/text_to_video", methods=["POST"])
+def text_to_video_api():
+    return text_to_video_module_api()
+
+@app.route("/api/v1/summarization", methods=["POST"])
+def summarization_api():
+    return summarization_module_api()
+
+@app.route("/api/v1/image_to_3d", methods=["POST"])
+def image_to_3d_api():
+    return image_to_3d_module_api()
+
+@app.route("/api/v1/xtts_clone", methods=["POST"])
+def xtts_clone_api():
+    return xtts_module_api()
+
+@app.route("/api/v1/sadtalker", methods=["POST"])
+def sadtalker():
+    from sadtalker_api import router as sadtalker_router
+    return sadtalker_router.create_video()
+
+if __name__ == "__main__":
+    with gr.Blocks() as demo:
+        gr.Markdown("## AI Powerhouse")
+        with gr.Tab("Text Generation"):
+            text_input = gr.Textbox(lines=5, placeholder="Enter text")
+            text_output = gr.Markdown()
+            text_button = gr.Button("Generate Text")
+            text_button.click(generate, inputs=text_input, outputs=text_output)
+
+        with gr.Tab("Image Generation"):
+            image_text_input = gr.Textbox(lines=3, placeholder="Enter prompt for image")
+            image_output = gr.Image()
+            image_button = gr.Button("Generate Image")
+            image_button.click(imagegen_api, inputs=image_text_input, outputs=image_output)
+
+        with gr.Tab("Music Generation"):
+            music_text_input = gr.Textbox(lines=3, placeholder="Enter prompt for music")
+            music_output = gr.Audio()
+            music_button = gr.Button("Generate Music")
+            music_button.click(musicgen_api, inputs=music_text_input, outputs=music_output)
+
+        with gr.Tab("Code Generation"):
+            code_text_input = gr.Textbox(lines=3, placeholder="Enter prompt for code")
+            code_output = gr.File()
+            code_button = gr.Button("Generate Code")
+            code_button.click(codegen_api, inputs=code_text_input, outputs=code_output)
+
+        with gr.Tab("Text to Video"):
+            video_text_input = gr.Textbox(lines=3, placeholder="Enter prompt for video")
+            video_output = gr.Video()
+            video_button = gr.Button("Generate Video")
+            video_button.click(text_to_video_api, inputs=video_text_input, outputs=video_output)
+
+        with gr.Tab("Summarization"):
+            summary_text_input = gr.Textbox(lines=5, placeholder="Enter text to summarize")
+            summary_output = gr.Textbox()
+            summary_button = gr.Button("Summarize")
+            summary_button.click(summarization_api, inputs=summary_text_input, outputs=summary_output)
+
+        with gr.Tab("Translation"):
+            translate_text_input = gr.Textbox(lines=3, placeholder="Enter text to translate")
+            translate_lang_dropdown = gr.Dropdown(['es', 'en', 'fr', 'de'], value='es', label="Target Language")
+            translation_output = gr.Textbox()
+            translate_button = gr.Button("Translate")
+            translate_button.click(translation_api, inputs=[translate_text_input, translate_lang_dropdown], outputs=translation_output)
+
+        with gr.Tab("Sentiment Analysis"):
+            sentiment_text_input = gr.Textbox(lines=3, placeholder="Enter text for sentiment analysis")
+            sentiment_output = gr.Textbox()
+            sentiment_button = gr.Button("Analyze Sentiment")
+            sentiment_button.click(sentiment_api, inputs=sentiment_text_input, outputs=sentiment_output)
+
+        with gr.Tab("Text to Speech"):
+            tts_text_input = gr.Textbox(lines=3, placeholder="Enter text for speech")
+            tts_output = gr.Audio()
+            tts_button = gr.Button("Generate Speech")
+            tts_button.click(tts_api, inputs=tts_text_input, outputs=tts_output)
+
+        with gr.Tab("Voice Cloning (XTTS)"):
+            xtts_text_input = gr.Textbox(lines=3, placeholder="Enter text for voice cloning")
+            xtts_audio_input = gr.Audio(source="upload", type="filepath", label="Reference Audio for Voice Cloning")
+            xtts_output = gr.Audio()
+            xtts_button = gr.Button("Clone Voice")
+            xtts_button.click(xtts_module_api, inputs=[xtts_text_input, xtts_audio_input], outputs=xtts_output)
+
+        with gr.Tab("Speech to Text"):
+            stt_audio_input = gr.Audio(source="microphone", type="filepath")
+            stt_output = gr.Textbox()
+            stt_button = gr.Button("Transcribe Speech")
+            stt_button.click(stt_api, inputs=stt_audio_input, outputs=stt_output)
+
+        with gr.Tab("Image to 3D"):
+            image_3d_input = gr.Image(source="upload", type="filepath")
+            model_3d_output = gr.File()
+            image_3d_button = gr.Button("Generate 3D Model")
+            image_3d_button.click(image_to_3d_api, inputs=image_3d_input, outputs=model_3d_output)
+
+    app = gr.routes.App(demo)
+
+    app.run(host="0.0.0.0", port=7860)
background_tasks.py
CHANGED
@@ -1,18 +1,9 @@
-import time
-import threading
-import queue
-import uuid
-import unicodedata
-import re
+import time, threading, queue, uuid, unicodedata, re
 from deep_translator import GoogleTranslator
 from duckduckgo_search import DDGS
-import nltk
-import torch
-import torch.nn as nn
-import math
+import nltk, torch, torch.nn as nn
 
 nltk.download('punkt')
-
 categories = ['News', 'Sports', 'Entertainment']
 TEXT_GENERATION_RATE = 10
 text_queue = queue.Queue()
@@ -25,7 +16,7 @@ news_clf = None
 
 class SimpleClassifier(nn.Module):
     def __init__(self, vocab_size, num_classes, embedding_dim=128):
-        super(SimpleClassifier, self).__init__()
+        super().__init__()
         self.embedding = nn.Embedding(vocab_size, embedding_dim)
         self.fc = nn.Linear(embedding_dim, num_classes)
     def forward(self, x):
@@ -34,91 +25,46 @@ class SimpleClassifier(nn.Module):
         out = self.fc(pooled)
         return out
 
-def tokenize_text(text):
-    [removed body not recoverable]
-
-def update_vocabulary(tokens):
-    global vocabulary, word_to_index
-    for token in tokens:
-        if token not in word_to_index:
-            word_to_index[token] = len(vocabulary)
-            vocabulary.append(token)
-
-def text_to_vector(text):
-    tokens = tokenize_text(text)
-    update_vocabulary(tokens)
-    indices = [word_to_index.get(token, 0) for token in tokens]
-    return torch.tensor(indices, dtype=torch.long).unsqueeze(0)
+def tokenize_text(text): return nltk.word_tokenize(text)
+def update_vocabulary(tokens):
+    global vocabulary, word_to_index
+    for token in tokens:
+        if token not in word_to_index:
+            word_to_index[token] = len(vocabulary)
+            vocabulary.append(token)
+def text_to_vector(text): tokens = tokenize_text(text); update_vocabulary(tokens); indices = [word_to_index.get(token, 0) for token in tokens]; return torch.tensor(indices, dtype=torch.long).unsqueeze(0)
 
 def generate_and_queue_text(language):
     global categories, text_queue
-    num_categories = len(categories)
-    num_texts_per_category = TEXT_GENERATION_RATE // (2 * num_categories)
+    num_categories = len(categories); num_texts_per_category = TEXT_GENERATION_RATE // (2 * num_categories)
     while True:
         for category in categories:
             for _ in range(num_texts_per_category):
-                uid = uuid.uuid4()
-                [three removed lines not recoverable]
-                    text = translator.translate(base_text)
-                except Exception:
-                    text = base_text
-                processed_text = ''.join(c for c in unicodedata.normalize('NFKC', text) if c.isprintable())
-                text_queue.put((processed_text, category))
-                time.sleep(0)
+                uid = uuid.uuid4(); base_text = f"Category: {category}. ID:{uid}"
+                try: translator = GoogleTranslator(source='auto', target=language); text = translator.translate(base_text)
+                except: text = base_text
+                processed_text = ''.join(c for c in unicodedata.normalize('NFKC', text) if c.isprintable()); text_queue.put((processed_text, category)); time.sleep(0)
 
 def background_training():
     global categories, news_clf, feedback_queue, vocabulary
-    if categories is None:
-        [two removed lines not recoverable]
-    learning_rate =
-    epochs = 1
-    if news_clf is None:
-        news_clf = SimpleClassifier(len(vocabulary), num_classes)
-    optimizer = torch.optim.SGD(news_clf.parameters(), lr=learning_rate)
-    criterion = nn.CrossEntropyLoss()
+    if categories is None: categories = ['DefaultCategory']
+    num_classes = len(categories); learning_rate = 0.01; epochs = 1
+    if news_clf is None: news_clf = SimpleClassifier(len(vocabulary), num_classes)
+    optimizer = torch.optim.SGD(news_clf.parameters(), lr=learning_rate); criterion = nn.CrossEntropyLoss()
     while True:
         try:
             feedback_item = feedback_queue.get(timeout=10)
            if feedback_item:
-                input_text, generated_text = feedback_item
-                [one removed line not recoverable]
-                if len(vocabulary)
-                [two removed lines not recoverable]
-                optimizer = torch.optim.SGD(news_clf.parameters(), lr=learning_rate)
-                if input_vector.size(0) != len(vocabulary) and len(vocabulary) > 0:
-                    news_clf = SimpleClassifier(len(vocabulary), num_classes)
-                    optimizer = torch.optim.SGD(news_clf.parameters(), lr=learning_rate)
-                    input_vector = text_to_vector(input_text)
-                tokens = tokenize_text(input_text)
-                update_vocabulary(tokens)
-                tokens_indices = [word_to_index.get(word, 0) for word in tokens]
-                input_tensor = torch.tensor([tokens_indices], dtype=torch.long)
-                target_index = categories.index(generated_text) if generated_text in categories else 0
+                input_text, generated_text = feedback_item; input_vector = text_to_vector(input_text)
+                if len(vocabulary) == 0: vocabulary.extend(["<PAD>", "<EOS>"]); news_clf = SimpleClassifier(len(vocabulary), num_classes); optimizer = torch.optim.SGD(news_clf.parameters(), lr=learning_rate)
+                if input_vector.size(0) != len(vocabulary) and len(vocabulary) > 0: news_clf = SimpleClassifier(len(vocabulary), num_classes); optimizer = torch.optim.SGD(news_clf.parameters(), lr=learning_rate); input_vector = text_to_vector(input_text)
+                tokens = tokenize_text(input_text); update_vocabulary(tokens); tokens_indices = [word_to_index.get(word, 0) for word in tokens]
+                input_tensor = torch.tensor([tokens_indices], dtype=torch.long); target_index = categories.index(generated_text) if generated_text in categories else 0
                 target_category_index = torch.tensor([target_index], dtype=torch.long)
-                if num_classes <= 1:
-                    [one removed line not recoverable]
-                    news_clf.fc = nn.Linear(128, num_classes)
-                for _ in range(epochs):
-                    optimizer.zero_grad()
-                    output = news_clf(input_tensor)
-                    loss = criterion(output, target_category_index)
-                    loss.backward()
-                    optimizer.step()
+                if num_classes <= 1: num_classes = 2; news_clf.fc = nn.Linear(128, num_classes)
+                for _ in range(epochs): optimizer.zero_grad(); output = news_clf(input_tensor); loss = criterion(output, target_category_index); loss.backward(); optimizer.step()
                 feedback_queue.task_done()
-        except queue.Empty:
-            [one removed line not recoverable]
-        except Exception:
-            time.sleep(5)
+        except queue.Empty: pass
+        except: time.sleep(5)
 
 def perform_reasoning_stream(text_input, temperature=0.7, top_k=40, top_p=0.0, repetition_penalty=1.2):
     for token in sample_sequence(text_input, model_gpt2, enc, length=999999999, temperature=temperature, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, device=device):
-        if token == "<END_STREAM>":
-            yield "<END_STREAM>"
-            break
+        if token == "<END_STREAM>": yield "<END_STREAM>"; break
         yield token + " "
 
 def background_reasoning_queue():
@@ -126,40 +72,21 @@ def background_reasoning_queue():
     while True:
        try:
            item = reasoning_queue.get(timeout=1)
-            if item is None:
-                [one removed line not recoverable]
-                continue
-            text_input = item.get('text_input')
-            temperature = item.get('temperature', 0.7)
-            top_k = item.get('top_k', 40)
-            top_p = item.get('top_p', 0.0)
-            repetition_penalty = item.get('repetition_penalty', 1.2)
+            if item is None: reasoning_queue.task_done(); continue
+            text_input = item.get('text_input'); temperature = item.get('temperature', 0.7); top_k = item.get('top_k', 40); top_p = item.get('top_p', 0.0); repetition_penalty = item.get('repetition_penalty', 1.2)
             resp_queue = item.get('response_queue', queue.Queue())
-            if not text_input:
-                resp_queue.put({"error": "Empty text input received."})
-                reasoning_queue.task_done()
-                continue
+            if not text_input: resp_queue.put({"error": "Empty text input received."}); reasoning_queue.task_done(); continue
             generated_text_stream = perform_reasoning_stream(text_input, temperature=temperature, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty)
-            full_response = ""
+            full_response = ""
             for chunk in generated_text_stream:
-                if chunk == "<END_STREAM>":
-                    break
+                if chunk == "<END_STREAM>": break
                 full_response += chunk
             cleaned_response = re.sub(r'\s+(?=[.,,。])', '', full_response.replace("<|endoftext|>", "").strip())
-            if cleaned_response in seen_responses:
-                [one removed line not recoverable]
-                resp_queue.put({"text": final_response})
-            else:
-                seen_responses.add(cleaned_response)
-                final_response = cleaned_response
-                resp_queue.put({"text": final_response})
+            if cleaned_response in seen_responses: final_response = "**Response is repetitive. Please try again or rephrase your query.**"; resp_queue.put({"text": final_response})
+            else: seen_responses.add(cleaned_response); final_response = cleaned_response; resp_queue.put({"text": final_response})
             reasoning_queue.task_done()
-        except queue.Empty:
-            [four removed lines not recoverable]
-        except Exception:
-            pass
-            if reasoning_queue and not reasoning_queue.empty():
-                reasoning_queue.task_done()
+        except queue.Empty: pass
+        except Exception as e:
+            try: resp_queue.put({"error": str(e)})
+            except: pass
+            if reasoning_queue and not reasoning_queue.empty(): reasoning_queue.task_done()
codegen_api.py
CHANGED
@@ -1,22 +1,13 @@
 from flask import jsonify, send_file, request
 from main import *
+import io, base64
 
 def generate_code(prompt, output_path="output_code.py"):
-    if codegen_model is None or codegen_tokenizer is None:
-        [two removed lines not recoverable]
-    output = codegen_model.generate(input_ids, max_length=2048, temperature=0.7, top_p=0.9)
-    code = codegen_tokenizer.decode(output[0], skip_special_tokens=True)
-    with open(output_path, "w") as file:
-        file.write(code)
-    return output_path
+    if codegen_model is None or codegen_tokenizer is None: return {"error": "Code generation model or tokenizer not initialized."}
+    input_ids = codegen_tokenizer(prompt, return_tensors='pt').to(device); output = codegen_model.generate(input_ids, max_length=2048, temperature=0.7, top_p=0.9)
+    code = codegen_tokenizer.decode(output[0], skip_special_tokens=True); return {"code": code}
 
-def codegen_api():
-    [three removed lines not recoverable]
-        return jsonify({"error": "Prompt is required"}), 400
-    output_file = generate_code(prompt)
-    if output_file == "Code generation model or tokenizer not initialized.":
-        return jsonify({"error": "Code generation failed"}), 500
-    return send_file(output_file, mimetype="text/x-python", as_attachment=True, download_name="output.py")
+def codegen_api(prompt):
+    output = generate_code(prompt)
+    if "error" in output: return {"error": output["error"]}
+    code_base64 = base64.b64encode(output['code'].encode('utf-8')).decode('utf-8'); return {"code_base64": code_base64, "mimetype": "text/x-python", "filename": "output.py"}
coder.py
ADDED
@@ -0,0 +1,139 @@
import dash
from dash import Dash, html, dcc, callback, Output, Input, State
from dash.exceptions import PreventUpdate
import base64, uuid, requests, io, re

app_style = None
external_stylesheets = ['style.css']
app = dash.Dash(__name__, external_stylesheets=external_stylesheets)
app.layout = html.Div([dcc.Location(id='url', refresh=False), html.Div(id='page-content')])
index_page = html.Div([html.H1("AI Powerhouse", className='index-title animate__animated animate__fadeInDown'), html.Div([dcc.Link('Text Generation', href='/text-generation', className='index-link animate__animated animate__fadeInUp'), dcc.Link('Audio Video', href='/audio-video', className='index-link animate__animated animate__fadeInUp'), dcc.Link('Image Generation', href='/image-generation', className='index-link animate__animated animate__fadeInUp'), dcc.Link('Image to 3D', href='/image-to-3d', className='index-link animate__animated animate__fadeInUp'), dcc.Link('Text to Video', href='/text-to-video', className='index-link animate__animated animate__fadeInUp'), dcc.Link('Music Generation', href='/music-generation', className='index-link animate__animated animate__fadeInUp'), dcc.Link('Sentiment Analysis', href='/sentiment-analysis', className='index-link animate__animated animate__fadeInUp'), dcc.Link('Translation', href='/translation', className='index-link animate__animated animate__fadeInUp'), dcc.Link('Code Generation', href='/code-generation', className='index-link animate__animated animate__fadeInUp'), dcc.Link('Summarization', href='/summarization', className='index-link animate__animated animate__fadeInUp'), dcc.Link('SadTalker', href='/sad-talker', className='index-link animate__animated animate__fadeInUp')], className='index-links-container')], className='index-page page-layout')
text_generation_layout = html.Div([html.Div(className='animated-text animate__animated animate__fadeIn animate__infinite infinite', children='AI POWERED Text Generation'), html.Div(className='chat-container page-layout', children=[html.Div(className='chat-header animate__animated animate__fadeInDown', children="Text Generation Interface"), html.Div(className='chat-form animate__animated animate__fadeInLeft', children=[dcc.Textarea(id='text-input', placeholder='Enter text prompt...', rows=4, className='chat-text-area'), html.Button('Generate', id='generate-button', n_clicks=0, className='chat-generate-button')]), html.Div(id='output', className='chat-output animate__animated animate__fadeInUp', children=[html.Div(className='response-header', children="Response:"), dcc.Markdown(id='generated-text', className='response-text')]), dcc.Link('Back to Home', href='/', className='chat-back-link')])], className='page-layout text-gen-layout')
audio_video_layout = html.Div([html.Div(className='animated-text animate__animated animate__fadeIn animate__infinite infinite', children='AI POWERED Audio & Video Tools'), html.Div(className='av-container page-layout', children=[html.Div(className='av-header animate__animated animate__fadeInDown', children="Audio & Video Processing"), html.Div(className='av-upload-section animate__animated animate__fadeInLeft', children=[html.Div(className='upload-box', children=[dcc.Upload(id='upload-audio', children=html.Div(['Drag and Drop or ', html.A('Select Audio')]), className='upload-area'), html.Div(id='audio-output-text', className='upload-output', children="STT Output")]), html.Div(className='upload-box', children=[dcc.Upload(id='upload-image', children=html.Div(['Drag and Drop or ', html.A('Select Image')]), className='upload-area'), html.Div(id='image-output', className='upload-output', children="Image Uploaded")])]), html.Div(id='video-output', className='av-video-output animate__animated animate__fadeInUp', children="Video Output Area"), dcc.Link('Back to Home', href='/', className='av-back-link')])], className='page-layout av-layout')
image_generation_layout = html.Div([html.Div(className='animated-text animate__animated animate__fadeIn animate__infinite infinite', children='AI POWERED Image Generation'), html.Div(className='imagegen-container page-layout', children=[html.Div(className='imagegen-header animate__animated animate__fadeInDown', children="Image Generation Interface"), html.Div(className='imagegen-form animate__animated animate__fadeInLeft', children=[dcc.Textarea(id='imagegen-text-input', placeholder='Enter prompt for image...', rows=4, className='imagegen-text-area'), html.Button('Generate Image', id='generate-image-button', n_clicks=0, className='imagegen-generate-button')]), html.Div(id='image-output-display', className='imagegen-output animate__animated animate__fadeInUp'), dcc.Link('Back to Home', href='/', className='imagegen-back-link')])], className='page-layout imagegen-layout')
image_to_3d_layout = html.Div([html.Div(className='animated-text animate__animated animate__fadeIn animate__infinite infinite', children='AI POWERED Image to 3D Conversion'), html.Div(className='imagetod-container page-layout', children=[html.Div(className='imagetod-header animate__animated animate__fadeInDown', children="Image to 3D Model Conversion"), html.Div(className='imagetod-upload animate__animated animate__fadeInLeft', children=[dcc.Upload(id='upload-image-3d', children=html.Div(['Drag and Drop or ', html.A('Select Image')]), className='imagetod-upload-area')]), html.Div(id='3d-model-output', className='imagetod-output animate__animated animate__fadeInUp'), dcc.Link('Back to Home', href='/', className='imagetod-back-link')])], className='page-layout imagetod-layout')
text_to_video_layout = html.Div([html.Div(className='animated-text animate__animated animate__fadeIn animate__infinite infinite', children='AI POWERED Text to Video Generation'), html.Div(className='textvideo-container page-layout', children=[html.Div(className='textvideo-header animate__animated animate__fadeInDown', children="Text to Video Generation Interface"), html.Div(className='textvideo-form animate__animated animate__fadeInLeft', children=[dcc.Textarea(id='text-video-input', placeholder='Enter prompt for video...', rows=4, className='textvideo-text-area'), html.Button('Generate Video', id='generate-video-button', n_clicks=0, className='textvideo-generate-button')]), html.Div(id='video-gen-output', className='textvideo-output animate__animated animate__fadeInUp'), dcc.Link('Back to Home', href='/', className='textvideo-back-link')])], className='page-layout textvideo-layout')
music_generation_layout = html.Div([html.Div(className='animated-text animate__animated animate__fadeIn animate__infinite infinite', children='AI POWERED Music Generation'), html.Div(className='musicgen-container page-layout', children=[html.Div(className='musicgen-header animate__animated animate__fadeInDown', children="Music Generation Interface"), html.Div(className='musicgen-form animate__animated animate__fadeInLeft', children=[dcc.Textarea(id='musicgen-text-input', placeholder='Enter prompt for music...', rows=4, className='musicgen-text-area'), html.Button('Generate Music', id='generate-music-button', n_clicks=0, className='musicgen-generate-button')]), html.Div(id='music-output-display', className='musicgen-output animate__animated animate__fadeInUp'), dcc.Link('Back to Home', href='/', className='musicgen-back-link')])], className='page-layout musicgen-layout')
sentiment_analysis_layout = html.Div([html.Div(className='animated-text animate__animated animate__fadeIn animate__infinite infinite', children='AI POWERED Sentiment Analysis'), html.Div(className='sentiment-container page-layout', children=[html.Div(className='sentiment-header animate__animated animate__fadeInDown', children="Sentiment Analysis Tool"), html.Div(className='sentiment-form animate__animated animate__fadeInLeft', children=[dcc.Textarea(id='sentiment-text-input', placeholder='Enter text for analysis...', rows=4, className='sentiment-text-area'), html.Button('Analyze Sentiment', id='analyze-sentiment-button', n_clicks=0, className='sentiment-analyze-button')]), html.Div(id='sentiment-output-display', className='sentiment-output animate__animated animate__fadeInUp'), dcc.Link('Back to Home', href='/', className='sentiment-back-link')])], className='page-layout sentiment-layout')
translation_layout = html.Div([html.Div(className='animated-text animate__animated animate__fadeIn animate__infinite infinite', children='AI POWERED Translation Services'), html.Div(className='translation-container page-layout', children=[html.Div(className='translation-header animate__animated animate__fadeInDown', children="Translation Interface"), html.Div(className='translation-form animate__animated animate__fadeInLeft', children=[dcc.Textarea(id='translate-text-input', placeholder='Enter text to translate...', rows=4, className='translation-text-area'), dcc.Dropdown(id='target-language-dropdown', options=[{'label': 'Spanish', 'value': 'es'},{'label': 'English', 'value': 'en'},{'label': 'French', 'value': 'fr'},{'label': 'German', 'value': 'de'}], value='es', className='translation-dropdown'), html.Button('Translate', id='translate-button', n_clicks=0, className='translation-translate-button')]), html.Div(id='translation-output-display', className='translation-output animate__animated animate__fadeInUp'), dcc.Link('Back to Home', href='/', className='translation-back-link')])], className='page-layout translation-layout')
code_generation_layout = html.Div([html.Div(className='animated-text animate__animated animate__fadeIn animate__infinite infinite', children='AI POWERED Code Generation'), html.Div(className='codegen-container page-layout', children=[html.Div(className='codegen-header animate__animated animate__fadeInDown', children="Code Generation Interface"), html.Div(className='codegen-form animate__animated animate__fadeInLeft', children=[dcc.Textarea(id='codegen-text-input', placeholder='Enter prompt for code...', rows=4, className='codegen-text-area'), html.Button('Generate Code', id='generate-code-button', n_clicks=0, className='codegen-generate-button')]), html.Div(id='codegen-output-display', className='codegen-output animate__animated animate__fadeInUp'), dcc.Link('Back to Home', href='/', className='codegen-back-link')])], className='page-layout codegen-layout')
summarization_layout = html.Div([html.Div(className='animated-text animate__animated animate__fadeIn animate__infinite infinite', children='AI POWERED Text Summarization'), html.Div(className='summarization-container page-layout', children=[html.Div(className='summarization-header animate__animated animate__fadeInDown', children="Text Summarization Tool"), html.Div(className='summarization-form animate__animated animate__fadeInLeft', children=[dcc.Textarea(id='summarize-text-input', placeholder='Enter text to summarize...', rows=4, className='summarization-text-area'), html.Button('Summarize', id='summarize-button', n_clicks=0, className='summarization-summarize-button')]), html.Div(id='summarization-output-display', className='summarization-output animate__animated animate__fadeInUp'), dcc.Link('Back to Home', href='/', className='summarization-back-link')])], className='page-layout summarization-layout')
sadtalker_layout = html.Div([html.Div(className='animated-text animate__animated animate__fadeIn animate__infinite infinite', children='AI POWERED SadTalker'), html.Div(className='sadtalker-container page-layout', children=[html.Div(className='sadtalker-header animate__animated animate__fadeInDown', children="SadTalker Interface"), html.Div(className='sadtalker-upload animate__animated animate__fadeInLeft', children=[dcc.Upload(id='upload-sadtalker-image', children=html.Div(['Drag and Drop Image', html.Br(), html.I(className="fas fa-image upload-icon")]), className='sadtalker-upload-area', multiple=False), dcc.Upload(id='upload-sadtalker-audio', children=html.Div(['Drag and Drop Audio', html.Br(), html.I(className="fas fa-file-audio upload-icon")]), className='sadtalker-upload-area', multiple=False)]), html.Div(id='sadtalker-video-output', className='sadtalker-output animate__animated animate__fadeInUp'), dcc.Link('Back to Home', href='/', className='sadtalker-back-link')])], className='page-layout sadtalker-layout')
|
22 |
+
|
23 |
+
@app.callback(Output('page-content', 'children'), [Input('url', 'pathname')])
|
24 |
+
def display_page(pathname):
|
25 |
+
if pathname == '/text-generation': return text_generation_layout
|
26 |
+
elif pathname == '/audio-video': return audio_video_layout
|
27 |
+
elif pathname == '/image-generation': return image_generation_layout
|
28 |
+
elif pathname == '/image-to-3d': return image_to_3d_layout
|
29 |
+
elif pathname == '/text-to-video': return text_to_video_layout
|
30 |
+
elif pathname == '/music-generation': return music_generation_layout
|
31 |
+
elif pathname == '/sentiment-analysis': return sentiment_analysis_layout
|
32 |
+
elif pathname == '/translation': return translation_layout
|
33 |
+
elif pathname == '/code-generation': return code_generation_layout
|
34 |
+
elif pathname == '/summarization': return summarization_layout
|
35 |
+
elif pathname == '/sad-talker': return sadtalker_layout
|
36 |
+
else: return index_page
|
37 |
+
|
38 |
+
@app.callback(Output('generated-text', 'children'), Output('output', 'className'), Input('generate-button', 'n_clicks'), State('text-input', 'value'), prevent_initial_call=True)
|
39 |
+
def generate_reasoning_dash(n_clicks, text_input):
|
40 |
+
if not text_input: return "Please enter text.", 'chat-output animate__animated animate__fadeInUp error-output'
|
41 |
+
api_url = "/api/v1/generate"; payload = {"text": text_input}
|
42 |
+
try: response = requests.post("http://127.0.0.1:7860" + api_url, json=payload); response.raise_for_status(); data = response.json(); generated_text = data.get("response", "Error in backend response."); return generated_text, 'chat-output animate__animated animate__fadeInUp'
|
43 |
+
except requests.exceptions.RequestException as e: return f"Error communicating with backend: {e}", 'chat-output animate__animated animate__fadeInUp error-output'
|
44 |
+
|
45 |
+
@app.callback(Output('audio-output-text', 'children'), Output('video-output', 'children'), Output('audio-output-text', 'className'), Output('video-output', 'className'), Input('upload-audio', 'contents'), State('upload-audio', 'filename'), Input('upload-image', 'contents'), State('upload-image', 'filename'), prevent_initial_call=True)
|
46 |
+
def process_audio_video_dash(audio_contents, audio_filename, image_contents, image_filename):
|
47 |
+
stt_output_text = ""; video_display = ""; stt_class = 'upload-output'; video_class = 'av-video-output animate__animated animate__fadeInUp'
|
48 |
+
if audio_contents:
|
49 |
+
try: content_type, content_string = audio_contents.split(','); decoded_audio = base64.b64decode(content_string); audio_io = io.BytesIO(decoded_audio); files = {'audio': (audio_filename, audio_io, content_type)}; response = requests.post("http://127.0.0.1:7860/api/v1/stt", files=files); response.raise_for_status(); data = response.json(); stt_output_text = f"STT Output: {data.get('text', 'Transcription failed')}"; stt_class = 'upload-output success'
|
50 |
+
except requests.exceptions.RequestException as e: stt_output_text = f"STT Error: {e}"; stt_class = 'upload-output error'
|
51 |
+
|
52 |
+
if image_contents:
|
53 |
+
try: content_type, content_string = image_contents.split(','); decoded_image = base64.b64decode(content_string); image_io = io.BytesIO(decoded_image); files = {'image': (image_filename, image_io, content_type)}; response = requests.post("http://127.0.0.1:7860/api/v1/image_to_3d", files=files); response.raise_for_status(); video_display = "3D Model Feature Extracted (Check Backend Logs for Output)."; video_class = 'av-video-output animate__animated animate__fadeInUp success'
|
54 |
+
except requests.exceptions.RequestException as e: video_display = f"3D Error: {e}"; video_class = 'av-video-output animate__animated animate__fadeInUp error'
|
55 |
+
|
56 |
+
video_output_component = html.Div(video_display) if video_display else ""; return stt_output_text, video_output_component, stt_class, video_class
|
57 |
+
|
58 |
+
@app.callback(Output('image-output-display', 'children'), Output('image-output-display', 'className'), Input('generate-image-button', 'n_clicks'), State('imagegen-text-input', 'value'), prevent_initial_call=True)
|
59 |
+
def generate_image_dash(n_clicks, prompt):
|
60 |
+
if not prompt: return "Please enter a prompt for image generation.", 'imagegen-output animate__animated animate__fadeInUp error-output'
|
61 |
+
api_url = "/api/v1/imagegen"; payload = {"prompt": prompt}
|
62 |
+
try: response = requests.post("http://127.0.0.1:7860" + api_url, json=payload); response.raise_for_status(); image_base64 = base64.b64encode(response.content).decode('utf-8'); return html.Img(src=f'data:image/png;base64,{image_base64}', className='generated-image'), 'imagegen-output animate__animated animate__fadeInUp success-output'
|
63 |
+
except requests.exceptions.RequestException as e: return f"Image Generation Error: {e}", 'imagegen-output animate__animated animate__fadeInUp error-output'
|
64 |
+
|
65 |
+
@app.callback(Output('3d-model-output', 'children'), Output('3d-model-output', 'className'), Input('upload-image-3d', 'contents'), State('upload-image-3d', 'filename'), prevent_initial_call=True)
|
66 |
+
def process_image_to_3d_dash(contents, filename):
|
67 |
+
if contents is None: raise PreventUpdate
|
68 |
+
try:
|
69 |
+
content_type, content_string = contents.split(',')
|
70 |
+
decoded_image = base64.b64decode(content_string); image_io = io.BytesIO(decoded_image); files = {'image': (filename, image_io, content_type)}
|
71 |
+
response = requests.post("http://127.0.0.1:7860/api/v1/image_to_3d", files=files); response.raise_for_status()
|
72 |
+
content_disposition = response.headers.get('Content-Disposition'); download_filename = 'model_3d.obj'
|
73 |
+
filenames = re.findall('filename="([^"]+)"', content_disposition) if content_disposition else []
|
74 |
+
if filenames: download_filename = filenames[0]
|
75 |
+
model_base64 = base64.b64encode(response.content).decode('utf-8'); href = f'data:model/obj;base64,{model_base64}'
|
76 |
+
download_link = html.A('Download 3D Model', href=href, download=download_filename, className='download-link'); return download_link, 'imagetod-output animate__animated animate__fadeInUp success-output'
|
77 |
+
except requests.exceptions.RequestException as e: return f"3D Conversion Error: {e}", 'imagetod-output animate__animated animate__fadeInUp error-output'
|
78 |
+
|
79 |
+
@app.callback(Output('video-gen-output', 'children'), Output('video-gen-output', 'className'), Input('generate-video-button', 'n_clicks'), State('text-video-input', 'value'), prevent_initial_call=True)
|
80 |
+
def generate_video_dash(n_clicks, prompt):
|
81 |
+
if not prompt: return "Please enter a prompt for video generation.", 'textvideo-output animate__animated animate__fadeInUp error-output'
|
82 |
+
api_url = "/api/v1/text_to_video"; payload = {"prompt": prompt}
|
83 |
+
try: response = requests.post("http://127.0.0.1:7860" + api_url, json=payload); response.raise_for_status(); video_base64 = base64.b64encode(response.content).decode('utf-8'); return html.Video(src=f'data:video/mp4;base64,{video_base64}', controls=True, className='generated-video'), 'textvideo-output animate__animated animate__fadeInUp success-output'
|
84 |
+
except requests.exceptions.RequestException as e: return f"Video Generation Error: {e}", 'textvideo-output animate__animated animate__fadeInUp error-output'
|
85 |
+
|
86 |
+
@app.callback(Output('music-output-display', 'children'), Output('music-output-display', 'className'), Input('generate-music-button', 'n_clicks'), State('musicgen-text-input', 'value'), prevent_initial_call=True)
|
87 |
+
def generate_music_dash(n_clicks, prompt):
|
88 |
+
if not prompt: return "Please enter a prompt for music generation.", 'musicgen-output animate__animated animate__fadeInUp error-output'
|
89 |
+
api_url = "/api/v1/musicgen"; payload = {"prompt": prompt}
|
90 |
+
try: response = requests.post("http://127.0.0.1:7860" + api_url, json=payload); response.raise_for_status(); audio_base64 = base64.b64encode(response.content).decode('utf-8'); return html.Audio(src=f'data:audio/wav;base64,{audio_base64}', controls=True, className='generated-audio'), 'musicgen-output animate__animated animate__fadeInUp success-output'
|
91 |
+
except requests.exceptions.RequestException as e: return f"Music Generation Error: {e}", 'musicgen-output animate__animated animate__fadeInUp error-output'
|
92 |
+
|
93 |
+
@app.callback(Output('sentiment-output-display', 'children'), Output('sentiment-output-display', 'className'), Input('analyze-sentiment-button', 'n_clicks'), State('sentiment-text-input', 'value'), prevent_initial_call=True)
|
94 |
+
def analyze_sentiment_dash(n_clicks, text):
|
95 |
+
if not text: return "Please enter text for sentiment analysis.", 'sentiment-output animate__animated animate__fadeInUp error-output'
|
96 |
+
api_url = "/api/v1/sentiment"; payload = {"text": text}
|
97 |
+
try: response = requests.post("http://127.0.0.1:7860" + api_url, json=payload); response.raise_for_status(); data = response.json(); sentiment_label = data.get('sentiment', 'Analysis Failed'); return f"Sentiment: {sentiment_label}", 'sentiment-output animate__animated animate__fadeInUp success-output'
|
98 |
+
except requests.exceptions.RequestException as e: return f"Sentiment Analysis Error: {e}", 'sentiment-output animate__animated animate__fadeInUp error-output'
|
99 |
+
|
100 |
+
@app.callback(Output('translation-output-display', 'children'), Output('translation-output-display', 'className'), Input('translate-button', 'n_clicks'), State('translate-text-input', 'value'), State('target-language-dropdown', 'value'), prevent_initial_call=True)
|
101 |
+
def translate_text_dash(n_clicks, text, target_lang):
|
102 |
+
if not text: return "Please enter text for translation.", 'translation-output animate__animated animate__fadeInUp error-output'
|
103 |
+
api_url = "/api/v1/translation"; payload = {"text": text, "target_lang": target_lang}
|
104 |
+
try: response = requests.post("http://127.0.0.1:7860" + api_url, json=payload); response.raise_for_status(); data = response.json(); translation = data.get('translated_text', 'Translation Failed'); return f"Translation ({target_lang.upper()}): {translation}", 'translation-output animate__animated animate__fadeInUp success-output'
|
105 |
+
except requests.exceptions.RequestException as e: return f"Translation Error: {e}", 'translation-output animate__animated animate__fadeInUp error-output'
|
106 |
+
|
107 |
+
@app.callback(Output('codegen-output-display', 'children'), Output('codegen-output-display', 'className'), Input('generate-code-button', 'n_clicks'), State('codegen-text-input', 'value'), prevent_initial_call=True)
|
108 |
+
def generate_code_dash(n_clicks, prompt):
|
109 |
+
if not prompt: return "Please enter a prompt for code generation.", 'codegen-output animate__animated animate__fadeInUp error-output'
|
110 |
+
api_url = "/api/v1/codegen"; payload = {"prompt": prompt}
|
111 |
+
try: response = requests.post("http://127.0.0.1:7860" + api_url, json=payload); response.raise_for_status()
|
112 |
+
content_disposition = response.headers.get('Content-Disposition'); download_filename = 'code.py'
|
113 |
+
filenames = re.findall('filename="([^"]+)"', content_disposition) if content_disposition else []; download_filename = filenames[0] if filenames else download_filename
|
114 |
+
code_base64 = base64.b64encode(response.content).decode('utf-8'); download_link = html.A('Download Code', href=f'data:text/x-python;base64,{code_base64}', download=download_filename, className='download-link'); return download_link, 'codegen-output animate__animated animate__fadeInUp success-output'
|
115 |
+
except requests.exceptions.RequestException as e: return f"Code Generation Error: {e}", 'codegen-output animate__animated animate__fadeInUp error-output'
|
116 |
+
|
117 |
+
@app.callback(Output('summarization-output-display', 'children'), Output('summarization-output-display', 'className'), Input('summarize-button', 'n_clicks'), State('summarize-text-input', 'value'), prevent_initial_call=True)
|
118 |
+
def summarize_text_dash(n_clicks, text):
|
119 |
+
if not text: return "Please enter text for summarization.", 'summarization-output animate__animated animate__fadeInUp error-output'
|
120 |
+
api_url = "/api/v1/summarization"; payload = {"text": text}
|
121 |
+
try: response = requests.post("http://127.0.0.1:7860" + api_url, json=payload); response.raise_for_status()
|
122 |
+
content_disposition = response.headers.get('Content-Disposition'); download_filename = 'summary.txt'
|
123 |
+
filenames = re.findall('filename="([^"]+)"', content_disposition) if content_disposition else []; download_filename = filenames[0] if filenames else download_filename
|
124 |
+
summary_base64 = base64.b64encode(response.content).decode('utf-8'); download_link = html.A('Download Summary', href=f'data:text/plain;base64,{summary_base64}', download=download_filename, className='download-link'); return download_link, 'summarization-output animate__animated animate__fadeInUp success-output'
|
125 |
+
except requests.exceptions.RequestException as e: return f"Summarization Error: {e}", 'summarization-output animate__animated animate__fadeInUp error-output'
|
126 |
+
|
127 |
+
@app.callback(Output('sadtalker-video-output', 'children'), Output('sadtalker-video-output', 'className'), Input('upload-sadtalker-image', 'contents'), State('upload-sadtalker-image', 'filename'), Input('upload-sadtalker-audio', 'contents'), State('upload-sadtalker-audio', 'filename'), prevent_initial_call=True)
|
128 |
+
def process_sadtalker_dash(image_contents, image_filename, audio_contents, audio_filename):
|
129 |
+
if not image_contents or not audio_contents: return "Please upload both image and audio for SadTalker.", 'sadtalker-output animate__animated animate__fadeInUp error-output'
|
130 |
+
try:
|
131 |
+
image_content_type, image_content_string = image_contents.split(','); decoded_image = base64.b64decode(image_content_string); image_io = io.BytesIO(decoded_image)
|
132 |
+
audio_content_type, audio_content_string = audio_contents.split(','); decoded_audio = base64.b64decode(audio_content_string); audio_io = io.BytesIO(decoded_audio)
|
133 |
+
files = {'source_image_file': (image_filename, image_io, image_content_type), 'driven_audio_file': (audio_filename, audio_io, audio_content_type)}
|
134 |
+
response = requests.post("http://127.0.0.1:7860/api/v1/sadtalker", files=files); response.raise_for_status(); data = response.json(); video_url = data.get('video_url')
|
135 |
+
if video_url: video_base64 = base64.b64encode(requests.get(video_url).content).decode('utf-8'); return html.Video(src=f'data:video/mp4;base64,{video_base64}', controls=True, className='generated-video'), 'sadtalker-output animate__animated animate__fadeInUp success-output'
|
136 |
+
else: return "SadTalker video generation failed, check backend logs.", 'sadtalker-output animate__animated animate__fadeInUp error-output'
|
137 |
+
except requests.exceptions.RequestException as e: return f"SadTalker Error: {e}", 'sadtalker-output animate__animated animate__fadeInUp error-output'
|
138 |
+
|
139 |
+
if __name__ == '__main__': app.run_server(host='0.0.0.0', port=7861, debug=False)
|
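Every callback above follows the same round trip: POST to the Flask backend on 127.0.0.1:7860 and unwrap either a JSON field or a binary payload. A minimal sketch of that round trip outside Dash, assuming the backend from main.py is already running and that /api/v1/generate returns {"response": ...} as the callback expects:

import requests

def backend_generate(text, base_url="http://127.0.0.1:7860"):
    # Mirrors generate_reasoning_dash: POST the prompt, surface HTTP errors, unwrap "response".
    resp = requests.post(f"{base_url}/api/v1/generate", json={"text": text}, timeout=60)
    resp.raise_for_status()
    return resp.json().get("response", "Error in backend response.")

if __name__ == "__main__":
    print(backend_generate("Hello"))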
image_to_3d_api.py
CHANGED
@@ -1,31 +1,19 @@
|
|
1 |
-
|
2 |
-
|
|
|
|
|
|
|
3 |
from flask import jsonify, send_file, request
|
4 |
from main import *
|
5 |
from PIL import Image
|
6 |
-
import torch
|
7 |
-
import numpy as np
|
8 |
|
9 |
def image_to_3d_func(image_path, output_path="output_3d.obj"):
|
10 |
-
if image_to_3d_model is None:
|
11 |
-
|
12 |
-
|
13 |
-
image = torch.tensor(np.array(pil_image)).float().permute(2,0,1).unsqueeze(0) / 255.0
|
14 |
-
image = image.to(device)
|
15 |
-
with torch.no_grad():
|
16 |
-
mesh_obj = image_to_3d_model(image)
|
17 |
-
with open(output_path, 'w') as f:
|
18 |
-
f.write(mesh_obj)
|
19 |
-
return output_path
|
20 |
|
21 |
-
def image_to_3d_api():
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
temp_image_path = f"temp_image_{uuid.uuid4()}.png"
|
26 |
-
image_file.save(temp_image_path)
|
27 |
-
output_file = image_to_3d_func(temp_image_path)
|
28 |
-
os.remove(temp_image_path)
|
29 |
-
if output_file == "Image-to-3D model not initialized.":
|
30 |
-
return jsonify({"error": "Image to 3D failed"}), 500
|
31 |
-
return send_file(output_file, mimetype="model/obj", as_attachment=True, download_name="output_3d.obj")
|
|
|
5 |
+
import os, uuid
|
6 |
from flask import jsonify, send_file, request
|
7 |
from main import *
|
8 |
from PIL import Image
|
9 |
+
import torch, numpy as np, io, base64
|
|
|
10 |
|
11 |
def image_to_3d_func(image_path, output_path="output_3d.obj"):
|
12 |
+
if image_to_3d_model is None: return {"error": "Image-to-3D model not initialized."}
|
13 |
+
pil_image = Image.open(image_path).convert("RGB"); image = torch.tensor(np.array(pil_image)).float().permute(2,0,1).unsqueeze(0) / 255.0; image = image.to(device)
|
14 |
+
with torch.no_grad(): mesh_obj = image_to_3d_model(image); return {"model_3d": mesh_obj}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
15 |
|
16 |
+
def image_to_3d_api(image_path):
|
17 |
+
output = image_to_3d_func(image_path)
|
18 |
+
if "error" in output: return {"error": output["error"]}
|
19 |
+
model_3d_base64 = base64.b64encode(output['model_3d'].encode('utf-8')).decode('utf-8'); return {"model_3d_base64": model_3d_base64, "mimetype": "model/obj", "filename": "output_3d.obj"}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
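Because image_to_3d_api now returns the mesh as base64 text rather than streaming a file, callers have to decode it themselves. A small consumer-side sketch, assuming the dict shape returned above:

import base64

def save_model_3d(result, out_path=None):
    # result is the dict produced by image_to_3d_api above; errors come back under "error".
    if "error" in result: raise RuntimeError(result["error"])
    out_path = out_path or result.get("filename", "output_3d.obj")
    with open(out_path, "wb") as f: f.write(base64.b64decode(result["model_3d_base64"]))
    return out_path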
imagegen_api.py
CHANGED
@@ -3,31 +3,15 @@ from flask import jsonify, send_file, request
|
|
3 |
from io import BytesIO
|
4 |
from PIL import Image
|
5 |
from main import *
|
6 |
-
import torch
|
7 |
|
8 |
def generate_image(prompt, output_path="output_image.png"):
|
9 |
-
if imagegen_model is None:
|
10 |
-
return "Image generation model not initialized."
|
11 |
-
|
12 |
generator = torch.Generator(device=device).manual_seed(0)
|
13 |
-
with torch.no_grad():
|
14 |
-
image = imagegen_model(
|
15 |
-
prompt,
|
16 |
-
generator=generator,
|
17 |
-
).images[0]
|
18 |
-
image.save(output_path)
|
19 |
-
return output_path
|
20 |
|
21 |
-
def imagegen_api():
|
22 |
-
data = request.get_json()
|
23 |
-
prompt = data.get('prompt')
|
24 |
-
if not prompt:
|
25 |
-
return jsonify({"error": "Prompt is required"}), 400
|
26 |
output_file = generate_image(prompt)
|
27 |
-
if output_file
|
28 |
-
|
29 |
-
|
30 |
-
pil_image = Image.open(output_file)
|
31 |
-
pil_image.save(image_io, 'PNG')
|
32 |
-
image_io.seek(0)
|
33 |
-
return send_file(image_io, mimetype='image/png', as_attachment=True, download_name="output.png")
|
|
|
3 |
from io import BytesIO
|
4 |
from PIL import Image
|
5 |
from main import *
|
6 |
+
import torch, base64
|
7 |
|
8 |
def generate_image(prompt, output_path="output_image.png"):
|
9 |
+
if imagegen_model is None: return {"error": "Image generation model not initialized."}
|
|
|
|
|
10 |
generator = torch.Generator(device=device).manual_seed(0)
|
11 |
+
with torch.no_grad(): image = imagegen_model(prompt, generator=generator).images[0]; image.save(output_path); return output_path
|
|
|
|
|
|
|
|
|
|
|
|
|
12 |
|
13 |
+
def imagegen_api(prompt):
|
|
|
|
|
|
|
|
|
14 |
output_file = generate_image(prompt)
|
15 |
+
if isinstance(output_file, dict) and "error" in output_file: return {"error": output_file["error"]}
|
16 |
+
image_io = BytesIO(); pil_image = Image.open(output_file); pil_image.save(image_io, 'PNG'); image_base64 = base64.b64encode(image_io.getvalue()).decode('utf-8')
|
17 |
+
os.remove(output_file); return {"image_base64": image_base64, "mimetype": "image/png"}
|
|
|
|
|
|
|
|
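imagegen_api no longer reads the Flask request itself, so some route has to feed it the prompt and serialize the dict it returns. A hypothetical wrapper (route name and registration are assumptions, not part of this upload):

from flask import request, jsonify

def imagegen_route():
    # Hypothetical glue between Flask and imagegen_api; adapt to however api.py registers routes.
    data = request.get_json(force=True) or {}
    prompt = data.get("prompt")
    if not prompt: return jsonify({"error": "Prompt is required"}), 400
    result = imagegen_api(prompt)
    return jsonify(result), (500 if "error" in result else 200)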
main.py
CHANGED
@@ -1,10 +1,4 @@
|
|
1 |
-
import threading
|
2 |
-
import queue
|
3 |
-
import time
|
4 |
-
import os
|
5 |
-
import nltk
|
6 |
-
import re
|
7 |
-
import json
|
8 |
from flask import Flask
|
9 |
from flask_cors import CORS
|
10 |
from api import *
|
@@ -19,50 +13,13 @@ from background_tasks import *
|
|
19 |
from text_generation import *
|
20 |
from sadtalker_utils import *
|
21 |
|
22 |
-
|
23 |
-
state_dict = None
|
24 |
-
enc = None
|
25 |
-
config = None
|
26 |
-
model_gpt2 = None
|
27 |
-
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
28 |
-
news_clf = None
|
29 |
-
tfidf_vectorizer = None
|
30 |
-
text_queue = queue.Queue()
|
31 |
-
categories = None
|
32 |
-
background_threads = []
|
33 |
-
feedback_queue = queue.Queue()
|
34 |
-
reasoning_queue = queue.Queue()
|
35 |
-
seen_responses = set()
|
36 |
-
dialogue_history = []
|
37 |
-
vocabulary = set()
|
38 |
-
word_to_index = {}
|
39 |
-
index_to_word = []
|
40 |
-
translation_model = None
|
41 |
-
sp = None
|
42 |
-
codegen_model = None
|
43 |
-
codegen_tokenizer = None
|
44 |
-
codegen_vocabulary = None
|
45 |
-
codegen_index_to_word = None
|
46 |
-
codegen_word_to_index = None
|
47 |
-
summarization_model = None
|
48 |
-
summarization_vocabulary = set()
|
49 |
-
summarization_word_to_index = {}
|
50 |
-
summarization_index_to_word = []
|
51 |
-
sadtalker_instance = None
|
52 |
-
imagegen_model = None
|
53 |
-
image_to_3d_model = None
|
54 |
-
text_to_video_model = None
|
55 |
-
stream_type = "text"
|
56 |
-
sentiment_model = None
|
57 |
-
stt_model = None
|
58 |
-
tts_model = None
|
59 |
-
musicgen_model = None
|
60 |
|
61 |
def load_models():
|
62 |
-
global model_gpt2, enc, translation_model, codegen_model, codegen_tokenizer,
|
63 |
model_gpt2, enc = initialize_gpt2_model(GPT2_FOLDER, {MODEL_FILE: MODEL_URL, ENCODER_FILE: ENCODER_URL, VOCAB_FILE: VOCAB_URL, CONFIG_FILE: GPT2CONFHG})
|
64 |
translation_model = initialize_translation_model(TRANSLATION_FOLDER, TRANSLATION_MODEL_FILES_URLS)
|
65 |
-
codegen_model, codegen_tokenizer,
|
66 |
summarization_model, _, _, _ = initialize_summarization_model(SUMMARIZATION_FOLDER, SUMMARIZATION_FILES_URLS)
|
67 |
imagegen_model = initialize_imagegen_model(IMAGEGEN_FOLDER, IMAGEGEN_FILES_URLS)
|
68 |
image_to_3d_model = initialize_image_to_3d_model(IMAGE_TO_3D_FOLDER, IMAGE_TO_3D_FILES_URLS)
|
@@ -71,6 +28,7 @@ def load_models():
|
|
71 |
stt_model = initialize_stt_model(STT_FOLDER, STT_FILES_URLS)
|
72 |
tts_model = initialize_tts_model(TTS_FOLDER, TTS_FILES_URLS)
|
73 |
musicgen_model = initialize_musicgen_model(MUSICGEN_FOLDER, MUSICGEN_FILES_URLS)
|
|
|
74 |
sadtalker_instance = SadTalker(checkpoint_path='./checkpoints', config_path='./src/config')
|
75 |
|
76 |
if __name__ == "__main__":
|
@@ -78,13 +36,8 @@ if __name__ == "__main__":
|
|
78 |
load_models()
|
79 |
categories = ['Category1', 'Category2', 'Category3', 'Category4', 'Category5']
|
80 |
import background_tasks
|
81 |
-
background_tasks.categories = categories
|
82 |
-
|
83 |
-
|
84 |
-
background_threads.
|
85 |
-
background_threads.append(threading.Thread(target=generate_and_queue_text, args=('es',), daemon=True))
|
86 |
-
background_threads.append(threading.Thread(target=background_training, daemon=True))
|
87 |
-
background_threads.append(threading.Thread(target=background_reasoning_queue, daemon=True))
|
88 |
-
for thread in background_threads:
|
89 |
-
thread.start()
|
90 |
app.run(host='0.0.0.0', port=7860)
|
|
|
1 |
+
import threading, queue, time, os, nltk, re, json
|
|
|
|
|
|
|
|
|
|
|
|
|
2 |
from flask import Flask
|
3 |
from flask_cors import CORS
|
4 |
from api import *
|
|
|
13 |
from text_generation import *
|
14 |
from sadtalker_utils import *
|
15 |
|
16 |
+
state_dict, enc, config, model_gpt2, device, news_clf, tfidf_vectorizer, text_queue, categories, background_threads, feedback_queue, reasoning_queue, seen_responses, dialogue_history, vocabulary, word_to_index, index_to_word, translation_model, sp, codegen_model, codegen_tokenizer, codegen_vocabulary, codegen_index_to_word, codegen_word_to_index, summarization_model, summarization_vocabulary, summarization_word_to_index, summarization_index_to_word, sadtalker_instance, imagegen_model, image_to_3d_model, text_to_video_model, stream_type, sentiment_model, stt_model, tts_model, musicgen_model, xtts_model = None, None, None, None, torch.device("cuda" if torch.cuda.is_available() else "cpu"), None, None, queue.Queue(), None, [], queue.Queue(), queue.Queue(), set(), [], set(), {}, [], None, None, None, None, None, None, None, None, set(), {}, [], None, None, None, None, "text", None, None, None, None, None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
17 |
|
18 |
def load_models():
|
19 |
+
global model_gpt2, enc, translation_model, codegen_model, codegen_tokenizer, summarization_model, imagegen_model, image_to_3d_model, text_to_video_model, sadtalker_instance, sentiment_model, stt_model, tts_model, musicgen_model, xtts_model
|
20 |
model_gpt2, enc = initialize_gpt2_model(GPT2_FOLDER, {MODEL_FILE: MODEL_URL, ENCODER_FILE: ENCODER_URL, VOCAB_FILE: VOCAB_URL, CONFIG_FILE: GPT2CONFHG})
|
21 |
translation_model = initialize_translation_model(TRANSLATION_FOLDER, TRANSLATION_MODEL_FILES_URLS)
|
22 |
+
codegen_model, codegen_tokenizer, _, _, _ = initialize_codegen_model(CODEGEN_FOLDER, CODEGEN_FILES_URLS)
|
23 |
summarization_model, _, _, _ = initialize_summarization_model(SUMMARIZATION_FOLDER, SUMMARIZATION_FILES_URLS)
|
24 |
imagegen_model = initialize_imagegen_model(IMAGEGEN_FOLDER, IMAGEGEN_FILES_URLS)
|
25 |
image_to_3d_model = initialize_image_to_3d_model(IMAGE_TO_3D_FOLDER, IMAGE_TO_3D_FILES_URLS)
|
|
|
28 |
stt_model = initialize_stt_model(STT_FOLDER, STT_FILES_URLS)
|
29 |
tts_model = initialize_tts_model(TTS_FOLDER, TTS_FILES_URLS)
|
30 |
musicgen_model = initialize_musicgen_model(MUSICGEN_FOLDER, MUSICGEN_FILES_URLS)
|
31 |
+
xtts_model = initialize_xtts_model(XTTS_FOLDER, XTTS_FILES_URLS)
|
32 |
sadtalker_instance = SadTalker(checkpoint_path='./checkpoints', config_path='./src/config')
|
33 |
|
34 |
if __name__ == "__main__":
|
|
|
36 |
load_models()
|
37 |
categories = ['Category1', 'Category2', 'Category3', 'Category4', 'Category5']
|
38 |
import background_tasks
|
39 |
+
background_tasks.categories = categories; background_tasks.text_queue = text_queue; background_tasks.reasoning_queue = reasoning_queue
|
40 |
+
background_threads.append(threading.Thread(target=generate_and_queue_text, args=('en',), daemon=True)); background_threads.append(threading.Thread(target=generate_and_queue_text, args=('es',), daemon=True))
|
41 |
+
background_threads.append(threading.Thread(target=background_training, daemon=True)); background_threads.append(threading.Thread(target=background_reasoning_queue, daemon=True))
|
42 |
+
for thread in background_threads: thread.start()
|
|
|
|
|
|
|
|
|
|
|
43 |
app.run(host='0.0.0.0', port=7860)
|
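The background threads are plain daemon threads sharing queue.Queue objects that main.py hands to background_tasks before starting them. A stripped-down sketch of that producer/consumer wiring in isolation, with a hypothetical worker standing in for the real ones in background_tasks.py:

import threading, queue

work_queue = queue.Queue()

def worker(q):
    # Daemon loop mirroring the generate_and_queue_text pattern: block on the queue, process, repeat.
    while True:
        item = q.get()
        print("processing", item)
        q.task_done()

threading.Thread(target=worker, args=(work_queue,), daemon=True).start()
work_queue.put("example item")
work_queue.join()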
model_loader.py
CHANGED
@@ -1,725 +1,229 @@
|
|
1 |
-
from tokenxxx import *
|
2 |
-
from constants import *
|
3 |
-
from utils import *
|
4 |
-
import os
|
5 |
-
import
|
6 |
-
import
|
7 |
-
import
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
def
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
def
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
def
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
def
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
|
94 |
-
|
95 |
-
|
96 |
-
|
97 |
-
|
98 |
-
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
-
|
112 |
-
|
113 |
-
|
114 |
-
|
115 |
-
|
116 |
-
|
117 |
-
|
118 |
-
|
119 |
-
|
120 |
-
|
121 |
-
|
122 |
-
def
|
123 |
-
|
124 |
-
|
125 |
-
|
126 |
-
|
127 |
-
|
128 |
-
|
129 |
-
|
130 |
-
|
131 |
-
|
132 |
-
|
133 |
-
|
134 |
-
|
135 |
-
|
136 |
-
|
137 |
-
|
138 |
-
|
139 |
-
|
140 |
-
|
141 |
-
|
142 |
-
|
143 |
-
|
144 |
-
|
145 |
-
|
146 |
-
|
147 |
-
|
148 |
-
|
149 |
-
|
150 |
-
class
|
151 |
-
def
|
152 |
-
|
153 |
-
|
154 |
-
|
155 |
-
def
|
156 |
-
|
157 |
-
|
158 |
-
|
159 |
-
|
160 |
-
|
161 |
-
|
162 |
-
|
163 |
-
|
164 |
-
return
|
165 |
-
|
166 |
-
|
167 |
-
|
168 |
-
self.
|
169 |
-
|
170 |
-
|
171 |
-
|
172 |
-
|
173 |
-
|
174 |
-
|
175 |
-
|
176 |
-
|
177 |
-
|
178 |
-
|
179 |
-
|
180 |
-
|
181 |
-
|
182 |
-
|
183 |
-
|
184 |
-
|
185 |
-
|
186 |
-
|
187 |
-
|
188 |
-
|
189 |
-
|
190 |
-
|
191 |
-
|
192 |
-
|
193 |
-
|
194 |
-
|
195 |
-
|
196 |
-
|
197 |
-
|
198 |
-
|
199 |
-
|
200 |
-
|
201 |
-
|
202 |
-
|
203 |
-
|
204 |
-
|
205 |
-
|
206 |
-
|
207 |
-
|
208 |
-
|
209 |
-
|
210 |
-
|
211 |
-
|
212 |
-
|
213 |
-
|
214 |
-
|
215 |
-
|
216 |
-
|
217 |
-
|
218 |
-
|
219 |
-
|
220 |
-
|
221 |
-
|
222 |
-
|
223 |
-
|
224 |
-
|
225 |
-
|
226 |
-
|
227 |
-
|
228 |
-
|
229 |
-
|
230 |
-
|
231 |
-
class CodeGenForCausalLM(nn.Module):
|
232 |
-
def __init__(self, config):
|
233 |
-
super().__init__()
|
234 |
-
d_model = getattr(config, "d_model", 1024)
|
235 |
-
n_head = getattr(config, "n_head", 16)
|
236 |
-
num_layers = getattr(config, "num_layers", 12)
|
237 |
-
dlayer = nn.TransformerDecoderLayer(d_model=d_model, nhead=n_head)
|
238 |
-
self.transformer_decoder = nn.TransformerDecoder(dlayer, num_layers=num_layers)
|
239 |
-
self.lm_head = nn.Linear(d_model, config.vocab_size)
|
240 |
-
def forward(self, tgt, memory=None):
|
241 |
-
if memory is None:
|
242 |
-
memory = torch.zeros_like(tgt)
|
243 |
-
return self.lm_head(self.transformer_decoder(tgt, memory))
|
244 |
-
|
245 |
-
class BartForConditionalGeneration(nn.Module):
|
246 |
-
def __init__(self, config):
|
247 |
-
super().__init__()
|
248 |
-
layer = nn.TransformerEncoderLayer(d_model=768, nhead=12)
|
249 |
-
self.encoder = nn.TransformerEncoder(layer, num_layers=6)
|
250 |
-
dlayer = nn.TransformerDecoderLayer(d_model=768, nhead=12)
|
251 |
-
self.decoder = nn.TransformerDecoder(dlayer, num_layers=6)
|
252 |
-
self.output_layer = nn.Linear(768, config.vocab_size)
|
253 |
-
def forward(self, src, tgt):
|
254 |
-
return self.output_layer(self.decoder(tgt, self.encoder(src)))
|
255 |
-
|
256 |
-
class ResnetBlock(nn.Module):
|
257 |
-
def __init__(self, in_ch, out_ch):
|
258 |
-
super().__init__()
|
259 |
-
self.norm1 = nn.GroupNorm(32, in_ch)
|
260 |
-
self.conv1 = nn.Conv2d(in_ch, out_ch, 3, padding=1)
|
261 |
-
self.norm2 = nn.GroupNorm(32, out_ch)
|
262 |
-
self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1)
|
263 |
-
self.conv_shortcut = nn.Conv2d(in_ch, out_ch, 1)
|
264 |
-
def forward(self, x):
|
265 |
-
sc = self.conv_shortcut(x)
|
266 |
-
h = F.silu(self.norm1(x))
|
267 |
-
h = self.conv1(h)
|
268 |
-
h = F.silu(self.norm2(x))
|
269 |
-
h = self.conv2(h)
|
270 |
-
return h + sc
|
271 |
-
|
272 |
-
class Downsample(nn.Module):
|
273 |
-
def __init__(self, in_ch, out_ch):
|
274 |
-
super().__init__()
|
275 |
-
self.conv = nn.Conv2d(in_ch, out_ch, 3, stride=2, padding=1)
|
276 |
-
def forward(self, x):
|
277 |
-
return self.conv(x)
|
278 |
-
|
279 |
-
class DownBlock(nn.Module):
|
280 |
-
def __init__(self, in_ch, out_ch, num_res):
|
281 |
-
super().__init__()
|
282 |
-
self.resnets = nn.ModuleList([ResnetBlock(in_ch if i == 0 else out_ch, out_ch) for i in range(num_res)])
|
283 |
-
self.downsamplers = nn.ModuleList([Downsample(out_ch, out_ch)])
|
284 |
-
def forward(self, x):
|
285 |
-
for r in self.resnets:
|
286 |
-
x = r(x)
|
287 |
-
for ds in self.downsamplers:
|
288 |
-
x = ds(x)
|
289 |
-
return x
|
290 |
-
|
291 |
-
class Upsample(nn.Module):
|
292 |
-
def __init__(self, in_ch, out_ch):
|
293 |
-
super().__init__()
|
294 |
-
self.conv = nn.ConvTranspose2d(in_ch, out_ch, 4, stride=2, padding=1)
|
295 |
-
def forward(self, x):
|
296 |
-
return self.conv(x)
|
297 |
-
|
298 |
-
class UpBlock(nn.Module):
|
299 |
-
def __init__(self, in_ch, out_ch, num_res):
|
300 |
-
super().__init__()
|
301 |
-
self.resnets = nn.ModuleList([ResnetBlock(in_ch if i == 0 else out_ch, out_ch) for i in range(num_res)])
|
302 |
-
self.upsampler = Upsample(out_ch, out_ch)
|
303 |
-
def forward(self, x):
|
304 |
-
for r in self.resnets:
|
305 |
-
x = r(x)
|
306 |
-
return self.upsampler(x)
|
307 |
-
|
308 |
-
class AttentionBlock(nn.Module):
|
309 |
-
def __init__(self, ch):
|
310 |
-
super().__init__()
|
311 |
-
self.norm = nn.GroupNorm(32, ch)
|
312 |
-
self.query = nn.Conv2d(ch, ch, 1)
|
313 |
-
self.key = nn.Conv2d(ch, ch, 1)
|
314 |
-
self.value = nn.Conv2d(ch, ch, 1)
|
315 |
-
self.proj_attn = nn.Conv2d(ch, ch, 1)
|
316 |
-
def forward(self, x):
|
317 |
-
b, c, h, w = x.shape
|
318 |
-
xn = self.norm(x)
|
319 |
-
q = self.query(xn).view(b, c, -1).permute(0, 2, 1)
|
320 |
-
k = self.key(xn).view(b, c, -1)
|
321 |
-
v = self.value(xn).view(b, c, -1).permute(0, 2, 1)
|
322 |
-
attn = torch.softmax(torch.bmm(q, k) / (c ** 0.5), dim=-1)
|
323 |
-
out = torch.bmm(attn, v).permute(0, 2, 1).view(b, c, h, w)
|
324 |
-
return x + self.proj_attn(out)
|
325 |
-
|
326 |
-
class Encoder(nn.Module):
|
327 |
-
def __init__(self, in_ch=3, base_ch=128, latent_ch=4):
|
328 |
-
super().__init__()
|
329 |
-
self.conv_in = nn.Conv2d(in_ch, base_ch, 3, padding=1)
|
330 |
-
self.down_blocks = nn.ModuleList([
|
331 |
-
DownBlock(base_ch, base_ch, 2),
|
332 |
-
DownBlock(base_ch, base_ch * 2, 2),
|
333 |
-
DownBlock(base_ch * 2, base_ch * 4, 2),
|
334 |
-
DownBlock(base_ch * 4, base_ch * 4, 2)
|
335 |
-
])
|
336 |
-
self.mid_block = nn.ModuleList([
|
337 |
-
ResnetBlock(base_ch * 4, base_ch * 4),
|
338 |
-
AttentionBlock(base_ch * 4),
|
339 |
-
ResnetBlock(base_ch * 4, base_ch * 4)
|
340 |
-
])
|
341 |
-
self.conv_norm_out = nn.GroupNorm(32, base_ch * 4)
|
342 |
-
self.conv_out = nn.Conv2d(base_ch * 4, latent_ch * 2, 3, padding=1)
|
343 |
-
self.quant_conv = nn.Conv2d(latent_ch * 2, latent_ch, 1)
|
344 |
-
def forward(self, x):
|
345 |
-
x = self.conv_in(x)
|
346 |
-
for blk in self.down_blocks:
|
347 |
-
x = blk(x)
|
348 |
-
for m in self.mid_block:
|
349 |
-
x = m(x)
|
350 |
-
x = self.conv_norm_out(x)
|
351 |
-
x = self.conv_out(x)
|
352 |
-
return self.quant_conv(x)
|
353 |
-
|
354 |
-
class Decoder(nn.Module):
|
355 |
-
def __init__(self, out_ch=3, base_ch=128, latent_ch=4):
|
356 |
-
super().__init__()
|
357 |
-
self.post_quant_conv = nn.Conv2d(latent_ch, latent_ch * 2, 1)
|
358 |
-
self.conv_in = nn.Conv2d(latent_ch, base_ch * 4, 3, padding=1)
|
359 |
-
self.mid_block = nn.ModuleList([
|
360 |
-
ResnetBlock(base_ch * 4, base_ch * 4),
|
361 |
-
AttentionBlock(base_ch * 4),
|
362 |
-
ResnetBlock(base_ch * 4, base_ch * 4)
|
363 |
-
])
|
364 |
-
self.up_blocks = nn.ModuleList([
|
365 |
-
UpBlock(base_ch * 4, base_ch * 4, 3),
|
366 |
-
UpBlock(base_ch * 4, base_ch * 2, 3),
|
367 |
-
UpBlock(base_ch * 2, base_ch, 3),
|
368 |
-
UpBlock(base_ch, base_ch, 3)
|
369 |
-
])
|
370 |
-
self.conv_norm_out = nn.GroupNorm(32, base_ch)
|
371 |
-
self.conv_out = nn.Conv2d(base_ch, out_ch, 3, padding=1)
|
372 |
-
def forward(self, x):
|
373 |
-
x = self.post_quant_conv(x)
|
374 |
-
x = self.conv_in(x)
|
375 |
-
for m in self.mid_block:
|
376 |
-
x = m(x)
|
377 |
-
for up in self.up_blocks:
|
378 |
-
x = up(x)
|
379 |
-
x = self.conv_norm_out(x)
|
380 |
-
return self.conv_out(x)
|
381 |
-
|
382 |
-
class AutoencoderKL(nn.Module):
|
383 |
-
def __init__(self, config):
|
384 |
-
super().__init__()
|
385 |
-
in_ch = config.get("in_channels", 3) if isinstance(config, dict) else config.__dict__.get("in_channels", 3)
|
386 |
-
out_ch = config.get("out_channels", 3) if isinstance(config, dict) else config.__dict__.get("out_channels", 3)
|
387 |
-
base_ch = config.get("base_channels", 128) if isinstance(config, dict) else config.__dict__.get("base_channels", 128)
|
388 |
-
latent_ch = config.get("latent_channels", 4) if isinstance(config, dict) else config.__dict__.get("latent_channels", 4)
|
389 |
-
self.encoder = Encoder(in_ch, base_ch, latent_ch)
|
390 |
-
self.decoder = Decoder(out_ch, base_ch, latent_ch)
|
391 |
-
def forward(self, x):
|
392 |
-
return self.decoder(self.encoder(x))
|
393 |
-
def decode(self, x):
|
394 |
-
return self.decoder(x)
|
395 |
-
|
396 |
-
class TransformerBlock(nn.Module):
|
397 |
-
def __init__(self, embed_dim, num_heads):
|
398 |
-
super().__init__()
|
399 |
-
self.norm1 = nn.LayerNorm(embed_dim)
|
400 |
-
self.attn = nn.MultiheadAttention(embed_dim, num_heads)
|
401 |
-
self.norm2 = nn.LayerNorm(embed_dim)
|
402 |
-
hidden_dim = embed_dim * 4
|
403 |
-
self.mlp = nn.Sequential(
|
404 |
-
nn.Linear(embed_dim, hidden_dim),
|
405 |
-
nn.GELU(),
|
406 |
-
nn.Linear(hidden_dim, embed_dim)
|
407 |
-
)
|
408 |
-
def forward(self, x):
|
409 |
-
res = x
|
410 |
-
x = self.norm1(x)
|
411 |
-
x = x.transpose(0, 1)
|
412 |
-
attn, _ = self.attn(x, x, x)
|
413 |
-
x = attn.transpose(0, 1)
|
414 |
-
x = res + x
|
415 |
-
return x + self.mlp(self.norm2(x))
|
416 |
-
|
417 |
-
class VisionTransformer(nn.Module):
|
418 |
-
def __init__(self, config):
|
419 |
-
super().__init__()
|
420 |
-
if isinstance(config, dict):
|
421 |
-
self.img_size = config.get("img_size", 592)
|
422 |
-
self.patch_size = config.get("patch_size", 16)
|
423 |
-
self.embed_dim = config.get("hidden_size", 768)
|
424 |
-
depth = config.get("depth", 12)
|
425 |
-
num_heads = config.get("num_heads", 12)
|
426 |
-
else:
|
427 |
-
self.img_size = config.__dict__.get("img_size", 592)
|
428 |
-
self.patch_size = config.__dict__.get("patch_size", 16)
|
429 |
-
self.embed_dim = config.__dict__.get("hidden_size", 768)
|
430 |
-
depth = config.__dict__.get("depth", 12)
|
431 |
-
num_heads = config.__dict__.get("num_heads", 12)
|
432 |
-
num_patches = (self.img_size // self.patch_size) ** 2
|
433 |
-
self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
|
434 |
-
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, self.embed_dim))
|
435 |
-
self.patch_embed = nn.Conv2d(3, self.embed_dim, kernel_size=self.patch_size, stride=self.patch_size)
|
436 |
-
self.blocks = nn.ModuleList([TransformerBlock(self.embed_dim, num_heads) for _ in range(depth)])
|
437 |
-
self.norm = nn.LayerNorm(self.embed_dim)
|
438 |
-
self.register_tokens = nn.Parameter(torch.zeros(1, 4, self.embed_dim))
|
439 |
-
self._init_weights()
|
440 |
-
def _init_weights(self):
|
441 |
-
nn.init.normal_(self.cls_token, std=0.02)
|
442 |
-
nn.init.normal_(self.pos_embed, std=0.02)
|
443 |
-
def forward(self, x):
|
444 |
-
x = self.patch_embed(x)
|
445 |
-
x = x.flatten(2).transpose(1, 2)
|
446 |
-
cls_tokens = self.cls_token.expand(x.shape[0], -1, -1)
|
447 |
-
x = torch.cat((cls_tokens, x), dim=1)
|
448 |
-
x = x + self.pos_embed
|
449 |
-
for blk in self.blocks:
|
450 |
-
x = blk(x)
|
451 |
-
return self.norm(x)[:, 0]
|
452 |
-
|
453 |
-
class OpenLRM(nn.Module):
|
454 |
-
def __init__(self, config):
|
455 |
-
super().__init__()
|
456 |
-
self.encoder = nn.ModuleDict({"model": VisionTransformer(config)})
|
457 |
-
hidden = config.get("hidden_size", 768) if isinstance(config, dict) else config.__dict__.get("hidden_size", 768)
|
458 |
-
self.linear = nn.Linear(hidden, hidden)
|
459 |
-
def forward(self, x):
|
460 |
-
return self.linear(self.encoder["model"](x))
|
461 |
-
|
462 |
-
class VideoUNet(nn.Module):
|
463 |
-
def __init__(self, in_ch=4, out_ch=4, features=None):
|
464 |
-
super().__init__()
|
465 |
-
if features is None:
|
466 |
-
features = [64, 128, 256]
|
467 |
-
self.encoder = nn.ModuleList()
|
468 |
-
self.pool = nn.MaxPool3d(2, 2)
|
469 |
-
self.decoder = nn.ModuleList()
|
470 |
-
for f in features:
|
471 |
-
self.encoder.append(nn.Sequential(
|
472 |
-
nn.Conv3d(in_ch, f, 3, padding=1),
|
473 |
-
nn.ReLU(inplace=True),
|
474 |
-
nn.Conv3d(f, f, 3, padding=1),
|
475 |
-
nn.ReLU(inplace=True)
|
476 |
-
))
|
477 |
-
in_ch = f
|
478 |
-
for f in reversed(features):
|
479 |
-
self.decoder.append(nn.Sequential(
|
480 |
-
nn.Conv3d(f * 2, f, 3, padding=1),
|
481 |
-
nn.ReLU(inplace=True),
|
482 |
-
nn.Conv3d(f, f, 3, padding=1),
|
483 |
-
nn.ReLU(inplace=True)
|
484 |
-
))
|
485 |
-
self.final_conv = nn.Conv3d(features[0], out_ch, 1)
|
486 |
-
def forward(self, x, t, encoder_hidden_states):
|
487 |
-
skips = []
|
488 |
-
for enc in self.encoder:
|
489 |
-
x = enc(x)
|
490 |
-
skips.append(x)
|
491 |
-
x = self.pool(x)
|
492 |
-
for dec in self.decoder:
|
493 |
-
skip = skips.pop()
|
494 |
-
x = F.interpolate(x, scale_factor=2, mode='trilinear', align_corners=False)
|
495 |
-
x = torch.cat([x, skip], dim=1)
|
496 |
-
x = dec(x)
|
497 |
-
return self.final_conv(x)
|
498 |
-
|
499 |
-
class SentimentClassifierModel(nn.Module):
|
500 |
-
def __init__(self, config):
|
501 |
-
super().__init__()
|
502 |
-
self.classifier = nn.Sequential(
|
503 |
-
nn.Linear(768, 256),
|
504 |
-
nn.ReLU(),
|
505 |
-
nn.Linear(256, 2)
|
506 |
-
)
|
507 |
-
def forward(self, x):
|
508 |
-
return self.classifier(x)
|
509 |
-
|
510 |
-
class STTModel(nn.Module):
|
511 |
-
def __init__(self, config):
|
512 |
-
super().__init__()
|
513 |
-
self.net = nn.Sequential(
|
514 |
-
nn.Linear(768, 512),
|
515 |
-
nn.ReLU(),
|
516 |
-
nn.Linear(512, 768)
|
517 |
-
)
|
518 |
-
def forward(self, x):
|
519 |
-
return self.net(x)
|
520 |
-
|
521 |
-
class TTSModel(nn.Module):
|
522 |
-
def __init__(self, config):
|
523 |
-
super().__init__()
|
524 |
-
self.net = nn.Sequential(
|
525 |
-
nn.Linear(768, 512),
|
526 |
-
nn.ReLU(),
|
527 |
-
nn.Linear(512, 768)
|
528 |
-
)
|
529 |
-
def forward(self, x):
|
530 |
-
return self.net(x)
|
531 |
-
|
532 |
-
class MusicGenModel(nn.Module):
|
533 |
-
def __init__(self, config):
|
534 |
-
super().__init__()
|
535 |
-
layer = nn.TransformerEncoderLayer(d_model=768, nhead=12)
|
536 |
-
self.transformer = nn.TransformerEncoder(layer, num_layers=12)
|
537 |
-
self.linear = nn.Linear(768, 768)
|
538 |
-
def forward(self, x):
|
539 |
-
return self.linear(self.transformer(x))
|
540 |
-
|
541 |
-
class SimpleTextEncoder(nn.Module):
|
542 |
-
def __init__(self, vocab_size=10000, embed_dim=768, max_length=77):
|
543 |
-
super().__init__()
|
544 |
-
self.embedding = nn.Embedding(vocab_size, embed_dim)
|
545 |
-
self.max_length = max_length
|
546 |
-
def forward(self, text_tokens):
|
547 |
-
return self.embedding(text_tokens)
|
548 |
-
|
549 |
-
class DiffusionScheduler:
|
550 |
-
def __init__(self, steps):
|
551 |
-
self.steps = steps
|
552 |
-
self.betas = torch.linspace(0.1, 0.001, steps=steps).to(device)
|
553 |
-
self.alphas = 1 - self.betas
|
554 |
-
self.alpha_bars = torch.cumprod(self.alphas, dim=0)
|
555 |
-
def step(self, noise, t, sample):
|
556 |
-
alpha_bar = self.alpha_bars[t]
|
557 |
-
alpha_bar_prev = self.alpha_bars[t-1] if t > 0 else torch.tensor(1.0, device=sample.device)
|
558 |
-
x0 = (sample - torch.sqrt(1 - alpha_bar) * noise) / torch.sqrt(alpha_bar)
|
559 |
-
new_sample = torch.sqrt(alpha_bar_prev) * x0 + torch.sqrt(1 - alpha_bar_prev) * noise
|
560 |
-
return new_sample
|
561 |
-
|
562 |
-
class VideoOutput:
|
563 |
-
def __init__(self, frames):
|
564 |
-
self.frames = [img_as_ubyte(frame) for frame in frames[0]]
|
565 |
-
|
566 |
-
class VideoPipeline(nn.Module):
|
567 |
-
def __init__(self, unet, vae, text_encoder, vocab):
|
568 |
-
super().__init__()
|
569 |
-
self.unet = unet
|
570 |
-
self.vae = vae
|
571 |
-
self.text_encoder = text_encoder
|
572 |
-
self.vocab = vocab
|
573 |
-
def forward(self, prompt: str, steps: int = 25, num_frames: int = 24):
|
574 |
-
token_ids = simple_tokenizer(prompt, self.vocab)
|
575 |
-
text_emb = self.text_encoder(token_ids)
|
576 |
-
latent = torch.randn((1, 4, num_frames, 64, 64), device=device).half()
|
577 |
-
sched = DiffusionScheduler(steps)
|
578 |
-
for t in range(steps):
|
579 |
-
noise = self.unet(latent, t, text_emb)
|
580 |
-
latent = sched.step(noise, t, latent)
|
581 |
-
frames = self.vae.decode(latent / 0.18215)
|
582 |
-
frames = frames.clamp(0, 1).float().cpu().permute(0, 2, 3, 4, 1).numpy()
|
583 |
-
return VideoOutput(frames)
|
584 |
-
|
585 |
-
def initialize_gpt2_model(folder, files):
|
586 |
-
download_files(folder, files)
|
587 |
-
config = GPT2Config()
|
588 |
-
model = GPT2LMHeadModel(config).to(device)
|
589 |
-
sd = torch.load(os.path.join(folder, sanitize_filename("gpt2-pytorch_model.bin")), map_location=device)
|
590 |
-
load_state_dict_safe(model, sd)
|
591 |
-
model.eval()
|
592 |
-
enc = read_json(os.path.join(folder, sanitize_filename("encoder.json")))
|
593 |
-
return model, enc
|
594 |
-
|
595 |
-
def initialize_translation_model(folder, files):
|
596 |
-
download_files(folder, files)
|
597 |
-
config = MBartConfig.from_dict(read_json(os.path.join(folder, "config.json")))
|
598 |
-
model = MBartForConditionalGeneration(config).to(device)
|
599 |
-
sd = torch.load(os.path.join(folder, "pytorch_model.bin"), map_location=device)
|
600 |
-
load_state_dict_safe(model, sd)
|
601 |
-
model.eval()
|
602 |
-
vp = os.path.join(folder, "vocab.json")
|
603 |
-
if os.path.exists(vp):
|
604 |
-
vocab = read_json(vp)
|
605 |
-
model.tokenizer = lambda txt: [vocab.get(t, 0) for t in txt.split()]
|
606 |
-
else:
|
607 |
-
model.tokenizer = lambda txt: txt
|
608 |
-
model.config.lang_code_to_id = {'en_XX': 0, 'es_XX': 1}
|
609 |
-
return model
|
610 |
-
|
611 |
-
def initialize_codegen_model(folder, files):
|
612 |
-
download_files(folder, files)
|
613 |
-
config = CodeGenConfig.from_dict(read_json(os.path.join(folder, "config.json")))
|
614 |
-
model = CodeGenForCausalLM(config).to(device)
|
615 |
-
sd = torch.load(os.path.join(folder, "pytorch_model.bin"), map_location=device)
|
616 |
-
load_state_dict_safe(model, sd)
|
617 |
-
model.eval()
|
618 |
-
tok = get_codegen_tokenizer(os.path.join(folder, "vocab.json"), os.path.join(folder, "merges.txt"))
|
619 |
-
vocab = read_json(os.path.join(folder, "vocab.json"))
|
620 |
-
idx2w = {v: k for k, v in vocab.items()}
|
621 |
-
model.tokenizer = tok
|
622 |
-
return model, tok, vocab, idx2w, vocab
|
623 |
-
|
624 |
-
def initialize_summarization_model(folder, files):
|
625 |
-
download_files(folder, files)
|
626 |
-
config = BartConfig.from_dict(read_json(os.path.join(folder, "config.json")))
|
627 |
-
model = BartForConditionalGeneration(config).to(device)
|
628 |
-
sd = torch.load(os.path.join(folder, "pytorch_model.bin"), map_location=device)
|
629 |
-
load_state_dict_safe(model, sd)
|
630 |
-
model.eval()
|
631 |
-
vp = os.path.join(folder, "vocab.json")
|
632 |
-
if os.path.exists(vp):
|
633 |
-
vocab_json = read_json(vp)
|
634 |
-
vocab = set(vocab_json.keys())
|
635 |
-
return model, vocab, vocab_json, {v: k for k, v in vocab_json.items()}
|
636 |
-
return model, None, None, None
|
637 |
-
|
638 |
-
def initialize_imagegen_model(folder, files):
|
639 |
-
download_files(folder, files)
|
640 |
-
config = AutoencoderKLConfig.from_dict(read_json(os.path.join(folder, "config.json")))
|
641 |
-
vae = AutoencoderKL(config).to(device)
|
642 |
-
sd = torch.load(os.path.join(folder, "diffusion_pytorch_model.bin"), map_location=device)
|
643 |
-
load_state_dict_safe(vae, sd)
|
644 |
-
vae.eval()
|
645 |
-
return vae
|
646 |
-
|
647 |
-
def initialize_image_to_3d_model(folder, files):
|
648 |
-
download_files(folder, files)
|
649 |
-
config = OpenLRMConfig.from_dict(read_json(os.path.join(folder, "config.json")))
|
650 |
-
model3d = OpenLRM(config).to(device)
|
651 |
-
sd = torch.load(os.path.join(folder, "pytorch_model.bin"), map_location=device)
|
652 |
-
load_state_dict_safe(model3d, sd)
|
653 |
-
model3d.eval()
|
654 |
-
return model3d
|
655 |
-
|
656 |
-
def initialize_text_to_video_model(folder, files):
|
657 |
-
download_files(folder, files)
|
658 |
-
unet_cfg = read_json(os.path.join(folder, "config.json"))
|
659 |
-
unet_cfg = filter_kwargs(VideoUNet, unet_cfg)
|
660 |
-
unet = VideoUNet(**unet_cfg).half().to(device)
|
661 |
-
sd_unet = torch.load(os.path.join(folder, "diffusion_pytorch_model.fp16.bin"), map_location=device)
|
662 |
-
load_state_dict_safe(unet, sd_unet)
|
663 |
-
unet.eval()
|
664 |
-
vae_cfg = read_json(os.path.join(folder, "config.json"))
|
665 |
-
vae_cfg = filter_kwargs(AutoencoderKL, vae_cfg)
|
666 |
-
vae = AutoencoderKL(vae_cfg).half().to(device)
|
667 |
-
sd_vae = torch.load(os.path.join(folder, "diffusion_pytorch_model.bin"), map_location=device)
|
668 |
-
load_state_dict_safe(vae, sd_vae)
|
669 |
-
vae.eval()
|
670 |
-
vp = os.path.join(folder, "vocab.json")
|
671 |
-
text_vocab = read_json(vp) if os.path.exists(vp) else {}
|
672 |
-
te_path = os.path.join(folder, "text_encoder.bin")
|
673 |
-
if os.path.exists(te_path):
|
674 |
-
text_encoder = SimpleTextEncoder(vocab_size=(max(text_vocab.values())+1) if text_vocab else 10000, embed_dim=768, max_length=77).to(device)
|
675 |
-
sd_te = torch.load(te_path, map_location=device)
|
676 |
-
load_state_dict_safe(text_encoder, sd_te)
|
677 |
-
else:
|
678 |
-
text_encoder = SimpleTextEncoder(vocab_size=(max(text_vocab.values())+1) if text_vocab else 10000, embed_dim=768, max_length=77).to(device)
|
679 |
-
text_encoder.eval()
|
680 |
-
return VideoPipeline(unet, vae, text_encoder, text_vocab)
|
681 |
-
|
682 |
-
def initialize_sentiment_model(folder, files):
|
683 |
-
download_files(folder, files)
|
684 |
-
config = BartConfig.from_dict(read_json(os.path.join(folder, "config.json")))
|
685 |
-
model = SentimentClassifierModel(config).to(device)
|
686 |
-
sd = torch.load(os.path.join(folder, "pytorch_model.bin"), map_location=device)
|
687 |
-
load_state_dict_safe(model, sd)
|
688 |
-
model.eval()
|
689 |
-
vp = os.path.join(folder, "vocab.json")
|
690 |
-
if os.path.exists(vp):
|
691 |
-
read_json(vp)
|
692 |
-
return model
|
693 |
-
|
694 |
-
def initialize_stt_model(folder, files):
|
695 |
-
download_files(folder, files)
|
696 |
-
config = BartConfig.from_dict(read_json(os.path.join(folder, "config.json")))
|
697 |
-
model = STTModel(config).to(device)
|
698 |
-
sd = torch.load(os.path.join(folder, "pytorch_model.bin"), map_location=device)
|
699 |
-
load_state_dict_safe(model, sd)
|
700 |
-
model.eval()
|
701 |
-
vp = os.path.join(folder, "vocab.json")
|
702 |
-
if os.path.exists(vp):
|
703 |
-
read_json(vp)
|
704 |
-
return model
|
705 |
-
|
706 |
-
def initialize_tts_model(folder, files):
|
707 |
-
download_files(folder, files)
|
708 |
-
config = BartConfig.from_dict(read_json(os.path.join(folder, "config.json")))
|
709 |
-
model = TTSModel(config).to(device)
|
710 |
-
sd = torch.load(os.path.join(folder, "pytorch_model.bin"), map_location=device)
|
711 |
-
load_state_dict_safe(model, sd)
|
712 |
-
model.eval()
|
713 |
-
vp = os.path.join(folder, "vocab.json")
|
714 |
-
if os.path.exists(vp):
|
715 |
-
read_json(vp)
|
716 |
-
return model
|
717 |
-
|
718 |
-
def initialize_musicgen_model(folder, files):
|
719 |
-
download_files(folder, files)
|
720 |
-
config = MusicGenConfig.from_dict(read_json(os.path.join(folder, "config.json")))
|
721 |
-
model = MusicGenModel(config).to(device)
|
722 |
-
sd = torch.load(os.path.join(folder, "pytorch_model.bin"), map_location=device)
|
723 |
-
load_state_dict_safe(model, sd)
|
724 |
-
model.eval()
|
725 |
-
return model
|
|
|
1 |
+
from tokenxxx import *
|
2 |
+
from constants import *
|
3 |
+
from utils import *
|
4 |
+
import os, json, urllib.request, urllib.parse, torch, torch.nn as nn, torch.nn.functional as F, hashlib, inspect
|
5 |
+
from tqdm import tqdm
|
6 |
+
from TTS.config import load_config
|
7 |
+
from TTS.tts.models.xtts import Xtts
|
8 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
9 |
+
|
10 |
+
def filter_kwargs(cls, kwargs): sig = inspect.signature(cls.__init__); accepted = set(sig.parameters.keys()) - {"self"}; return {k: v for k, v in kwargs.items() if k in accepted}
|
11 |
+
def sanitize_filename(name, url=None):
    for c in '<>:"/\\|?*': name = name.replace(c, '')
    if not name and url is not None: name = hashlib.md5(url.encode()).hexdigest()
    return name
|
12 |
+
def download_file(url, filepath):
    d = os.path.dirname(filepath)
    if d and not os.path.exists(d): os.makedirs(d, exist_ok=True)
    while not os.path.exists(filepath):
        # Retry until the file lands on disk; tqdm reports download progress.
        try:
            def prog(t):
                last = [0]
                def inner(n, bs, ts):
                    if ts > 0: t.total = ts
                    t.update(n * bs - last[0]); last[0] = n * bs
                return inner
            with tqdm(unit='B', unit_scale=True, unit_divisor=1024, desc=os.path.basename(filepath)) as t: urllib.request.urlretrieve(url, filepath, reporthook=prog(t))
        except Exception: continue
|
18 |
+
def download_files(folder, files_spec):
    if isinstance(files_spec, dict):
        for fn, url in files_spec.items():
            fn = sanitize_filename(fn, url); fp = os.path.join(folder, fn); download_file(url, fp)
    elif isinstance(files_spec, list):
        for item in files_spec:
            if isinstance(item, str):
                url = item; parsed = urllib.parse.urlparse(url); fn = os.path.basename(parsed.path)
                if not fn: fn = hashlib.md5(url.encode()).hexdigest()
                fn = sanitize_filename(fn, url)
            elif isinstance(item, (list, tuple)) and len(item) == 2: url, fn = item; fn = sanitize_filename(fn, url)
            elif isinstance(item, dict) and "filename" in item and "url" in item: fn = sanitize_filename(item["filename"], item["url"]); url = item["url"]
            else: raise ValueError("Invalid file specification")
            fp = os.path.join(folder, fn); download_file(url, fp)
    else: raise ValueError("files_spec must be dict or list")
|
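download_files accepts either a dict mapping filename to URL or a list whose items may be bare URLs, (url, filename) pairs, or {"filename": ..., "url": ...} dicts. A usage sketch with placeholder URLs standing in for the real constants in constants.py:

download_files("weights/gpt2", {"config.json": "https://example.com/config.json"})
download_files("weights/extra", [
    "https://example.com/vocab.json",                        # bare URL: filename taken from the path
    ("https://example.com/model.bin", "pytorch_model.bin"),  # explicit (url, filename) pair
    {"filename": "merges.txt", "url": "https://example.com/merges.txt"},
])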
28 |
+
|
29 |
+
def read_json(fp):
    with open(fp, 'r', encoding='utf-8') as f: return json.load(f)
|
30 |
+
def get_codegen_tokenizer(vocab_path, merges_path):
    with open(vocab_path, 'r', encoding='utf-8') as f: vocab = json.load(f)
    with open(merges_path, 'r', encoding='utf-8') as f: merges = f.read().splitlines()
    merge_ranks = {}
    for i, merge in enumerate(merges):
        parts = merge.strip().split()
        if len(parts) == 2: merge_ranks[tuple(parts)] = i
    def bpe(token):
        # Greedy BPE: repeatedly merge the lowest-ranked adjacent pair until none applies.
        word = list(token); pairs = [(word[i], word[i+1]) for i in range(len(word)-1)]
        while True:
            candidate = None; candidate_rank = None; candidate_index = None
            for i, pair in enumerate(pairs):
                if pair in merge_ranks:
                    rank = merge_ranks[pair]
                    if candidate is None or rank < candidate_rank: candidate = pair; candidate_rank = rank; candidate_index = i
            if candidate is None: break
            first, second = candidate; new_word = []; i = 0
            while i < len(word):
                try: j = word.index(first, i); new_word.extend(word[i:j]); i = j
                except ValueError: new_word.extend(word[i:]); break
                if word[i] == first and i < len(word)-1 and word[i+1] == second: new_word.append(first+second); i += 2
                else: new_word.append(word[i]); i += 1
            word = new_word
            if len(word) == 1: break
            pairs = [(word[i], word[i+1]) for i in range(len(word)-1)]
        return word
    def tokenizer(text):
        tokens = []
        for token in text.split():
            for subtoken in bpe(token): tokens.append(vocab.get(subtoken, 0))
        return tokens
    return tokenizer
|
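A quick usage sketch for the tokenizer factory; the vocab and merges paths here are placeholders for the files downloaded into CODEGEN_FOLDER:

tok = get_codegen_tokenizer("vocab.json", "merges.txt")
ids = tok("def add(a, b): return a + b")  # list of vocab ids; unknown subtokens map to 0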
52 |
+
|
53 |
+
def simple_tokenizer(text, vocab, max_length=77):
    toks = text.split(); ids = [vocab.get(t, 1) for t in toks]
    if len(ids) < max_length: ids = ids + [0] * (max_length - len(ids))
    else: ids = ids[:max_length]
    return torch.tensor(ids, dtype=torch.long).unsqueeze(0).to(device)
|
55 |
+
def load_state_dict_safe(model, loaded_state_dict):
    model_state = model.state_dict(); new_state = {}
    for key, value in model_state.items():
        if key in loaded_state_dict and loaded_state_dict[key].shape == value.shape: new_state[key] = loaded_state_dict[key]
        else: new_state[key] = value
    model.load_state_dict(new_state, strict=False)
|
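load_state_dict_safe copies only tensors whose shapes match, so a checkpoint from a slightly different architecture loads without raising. A throwaway illustration (names here are illustrative only):

import torch, torch.nn as nn

net = nn.Linear(4, 2)
ckpt = {"weight": torch.zeros(2, 4), "bias": torch.zeros(99)}  # bias shape mismatches on purpose
load_state_dict_safe(net, ckpt)  # weight is taken from the checkpoint, bias keeps its current value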
57 |
+
|
58 |
+
class GPT2Config:
    def __init__(self, vocab_size=50257, **kwargs): self.vocab_size = vocab_size; self.__dict__.update(kwargs)
    @classmethod
    def from_dict(cls, d): return cls(**d)
class MBartConfig:
    def __init__(self, vocab_size=50265, **kwargs): self.vocab_size = vocab_size; self.__dict__.update(kwargs)
    @classmethod
    def from_dict(cls, d): return cls(**d)
class CodeGenConfig:
    def __init__(self, vocab_size=50257, **kwargs): self.vocab_size = vocab_size; self.__dict__.update(kwargs)
    @classmethod
    def from_dict(cls, d): return cls(**d)
class BartConfig:
    def __init__(self, vocab_size=50265, **kwargs): self.vocab_size = vocab_size; self.__dict__.update(kwargs)
    @classmethod
    def from_dict(cls, d): return cls(**d)
class AutoencoderKLConfig:
    def __init__(self, **kwargs): self.__dict__.update(kwargs)
    @classmethod
    def from_dict(cls, d): return cls(**d)
class OpenLRMConfig:
    def __init__(self, **kwargs): self.__dict__.update(kwargs)
    @classmethod
    def from_dict(cls, d): return cls(**d)
class UNet2DConditionModelConfig:
    def __init__(self, **kwargs): self.__dict__.update(kwargs)
    @classmethod
    def from_dict(cls, d): return cls(**d)
class MusicGenConfig:
    def __init__(self, **kwargs): self.__dict__.update(kwargs)
    @classmethod
    def from_dict(cls, d): return cls(**d)
class XTTSConfig:
    def __init__(self, **kwargs): self.__dict__.update(kwargs)
    @classmethod
    def from_dict(cls, d): return cls(**d)
|
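The from_dict hook on each config pairs with read_json above; extra keys simply land on the instance through **kwargs. A short sketch, assuming config.json has already been downloaded into GPT2_FOLDER:

config = GPT2Config.from_dict(read_json(os.path.join(GPT2_FOLDER, "config.json")))
print(config.vocab_size)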
90 |
+
|
91 |
+
class GPT2LMHeadModel(nn.Module):
    def __init__(self, config): super().__init__(); layer = nn.TransformerEncoderLayer(d_model=768, nhead=12); self.transformer = nn.TransformerEncoder(layer, num_layers=12); self.lm_head = nn.Linear(768, config.vocab_size)
    def forward(self, x): return self.lm_head(self.transformer(x))
class MBartForConditionalGeneration(nn.Module):
    def __init__(self, config):
        super().__init__(); self.config = config
        layer = nn.TransformerEncoderLayer(d_model=768, nhead=12); self.encoder = nn.TransformerEncoder(layer, num_layers=6)
        dlayer = nn.TransformerDecoderLayer(d_model=768, nhead=12); self.decoder = nn.TransformerDecoder(dlayer, num_layers=6); self.output_layer = nn.Linear(768, config.vocab_size)
    def forward(self, src, tgt): return self.output_layer(self.decoder(tgt, self.encoder(src)))
class CodeGenForCausalLM(nn.Module):
    def __init__(self, config):
        super().__init__(); d_model = getattr(config, "d_model", 1024); n_head = getattr(config, "n_head", 16); num_layers = getattr(config, "num_layers", 12)
        dlayer = nn.TransformerDecoderLayer(d_model=d_model, nhead=n_head); self.transformer_decoder = nn.TransformerDecoder(dlayer, num_layers=num_layers); self.lm_head = nn.Linear(d_model, config.vocab_size)
    def forward(self, tgt, memory=None):
        if memory is None: memory = torch.zeros_like(tgt)
        return self.lm_head(self.transformer_decoder(tgt, memory))
class BartForConditionalGeneration(nn.Module):
    def __init__(self, config):
        super().__init__(); layer = nn.TransformerEncoderLayer(d_model=768, nhead=12); self.encoder = nn.TransformerEncoder(layer, num_layers=6)
        dlayer = nn.TransformerDecoderLayer(d_model=768, nhead=12); self.decoder = nn.TransformerDecoder(dlayer, num_layers=6); self.output_layer = nn.Linear(768, config.vocab_size)
    def forward(self, src, tgt): return self.output_layer(self.decoder(tgt, self.encoder(src)))
|
102 |
+
class ResnetBlock(nn.Module):
    def __init__(self, in_ch, out_ch): super().__init__(); self.norm1 = nn.GroupNorm(32, in_ch); self.conv1 = nn.Conv2d(in_ch, out_ch, 3, padding=1); self.norm2 = nn.GroupNorm(32, out_ch); self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1); self.conv_shortcut = nn.Conv2d(in_ch, out_ch, 1)
    def forward(self, x): sc = self.conv_shortcut(x); h = F.silu(self.norm1(x)); h = self.conv1(h); h = F.silu(self.norm2(x)); h = self.conv2(h); return h + sc
class Downsample(nn.Module):
    def __init__(self, in_ch, out_ch): super().__init__(); self.conv = nn.Conv2d(in_ch, out_ch, 3, stride=2, padding=1)
    def forward(self, x): return self.conv(x)
class DownBlock(nn.Module):
    def __init__(self, in_ch, out_ch, num_res): super().__init__(); self.resnets = nn.ModuleList([ResnetBlock(in_ch if i == 0 else out_ch, out_ch) for i in range(num_res)]); self.downsamplers = nn.ModuleList([Downsample(out_ch, out_ch)])
    def forward(self, x):
        for r in self.resnets: x = r(x)
        for ds in self.downsamplers: x = ds(x)
        return x
class Upsample(nn.Module):
    def __init__(self, in_ch, out_ch): super().__init__(); self.conv = nn.ConvTranspose2d(in_ch, out_ch, 4, stride=2, padding=1)
    def forward(self, x): return self.conv(x)
class UpBlock(nn.Module):
    def __init__(self, in_ch, out_ch, num_res): super().__init__(); self.resnets = nn.ModuleList([ResnetBlock(in_ch if i == 0 else out_ch, out_ch) for i in range(num_res)]); self.upsampler = Upsample(out_ch, out_ch)
    def forward(self, x):
        for r in self.resnets: x = r(x)
        return self.upsampler(x)
|
112 |
+
class AttentionBlock(nn.Module): def __init__(self, ch): super().__init__(); self.norm = nn.GroupNorm(32, ch); self.query = nn.Conv2d(ch, ch, 1); self.key = nn.Conv2d(ch, ch, 1); self.value = nn.Conv2d(ch, ch, 1); self.proj_attn = nn.Conv2d(ch, ch, 1)
|
113 |
+
def forward(self, x): b, c, h, w = x.shape; xn = self.norm(x); q = self.query(xn).view(b, c, -1).permute(0, 2, 1); k = self.key(xn).view(b, c, -1); v = self.value(xn).view(b, c, -1).permute(0, 2, 1)
|
114 |
+
attn = torch.softmax(torch.bmm(q, k) / (c ** 0.5), dim=-1); out = torch.bmm(attn, v).permute(0, 2, 1).view(b, c, h, w); return x + self.proj_attn(out)
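A quick, hedged shape check for the spatial self-attention block above, assuming AttentionBlock is available at module level as defined; the channel count and feature-map size are illustrative.
import torch

block = AttentionBlock(64)       # 64 channels, divisible by the 32 GroupNorm groups
x = torch.randn(2, 64, 16, 16)
print(block(x).shape)            # torch.Size([2, 64, 16, 16]); the residual connection keeps the shape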
|
115 |
+
class Encoder(nn.Module): def __init__(self, in_ch=3, base_ch=128, latent_ch=4): super().__init__(); self.conv_in = nn.Conv2d(in_ch, base_ch, 3, padding=1); self.down_blocks = nn.ModuleList([DownBlock(base_ch, base_ch, 2), DownBlock(base_ch, base_ch * 2, 2), DownBlock(base_ch * 2, base_ch * 4, 2), DownBlock(base_ch * 4, base_ch * 4, 2)]); self.mid_block = nn.ModuleList([ResnetBlock(base_ch * 4, base_ch * 4), AttentionBlock(base_ch * 4), ResnetBlock(base_ch * 4, base_ch * 4)]); self.conv_norm_out = nn.GroupNorm(32, base_ch * 4); self.conv_out = nn.Conv2d(base_ch * 4, latent_ch * 2, 3, padding=1); self.quant_conv = nn.Conv2d(latent_ch * 2, latent_ch, 1)
|
116 |
+
def forward(self, x):
    x = self.conv_in(x)
    for blk in self.down_blocks: x = blk(x)
    for m in self.mid_block: x = m(x)
    x = self.conv_norm_out(x); x = self.conv_out(x); return self.quant_conv(x)
|
117 |
+
class Decoder(nn.Module): def __init__(self, out_ch=3, base_ch=128, latent_ch=4): super().__init__(); self.post_quant_conv = nn.Conv2d(latent_ch, latent_ch * 2, 1); self.conv_in = nn.Conv2d(latent_ch, base_ch * 4, 3, padding=1); self.mid_block = nn.ModuleList([ResnetBlock(base_ch * 4, base_ch * 4), AttentionBlock(base_ch * 4), ResnetBlock(base_ch * 4, base_ch * 4)]); self.up_blocks = nn.ModuleList([UpBlock(base_ch * 4, base_ch * 4, 3), UpBlock(base_ch * 4, base_ch * 2, 3), UpBlock(base_ch * 2, base_ch, 3), UpBlock(base_ch, base_ch, 3)]); self.conv_norm_out = nn.GroupNorm(32, base_ch); self.conv_out = nn.Conv2d(base_ch, out_ch, 3, padding=1)
|
118 |
+
def forward(self, x):
    x = self.post_quant_conv(x); x = self.conv_in(x)
    for m in self.mid_block: x = m(x)
    for up in self.up_blocks: x = up(x)
    x = self.conv_norm_out(x); return self.conv_out(x)
|
119 |
+
class AutoencoderKL(nn.Module): def __init__(self, config): super().__init__(); in_ch = config.get("in_channels", 3) if isinstance(config, dict) else config.__dict__.get("in_channels", 3)
|
120 |
+
out_ch = config.get("out_channels", 3) if isinstance(config, dict) else config.__dict__.get("out_channels", 3); base_ch = config.get("base_channels", 128) if isinstance(config, dict) else config.__dict__.get("base_channels", 128)
|
121 |
+
latent_ch = config.get("latent_channels", 4) if isinstance(config, dict) else config.__dict__.get("latent_channels", 4); self.encoder = Encoder(in_ch, base_ch, latent_ch); self.decoder = Decoder(out_ch, base_ch, latent_ch)
|
122 |
+
def forward(self, x): return self.decoder(self.encoder(x))
|
123 |
+
def decode(self, x): return self.decoder(x)
|
124 |
+
class TransformerBlock(nn.Module): def __init__(self, embed_dim, num_heads): super().__init__(); self.norm1 = nn.LayerNorm(embed_dim); self.attn = nn.MultiheadAttention(embed_dim, num_heads); self.norm2 = nn.LayerNorm(embed_dim)
|
125 |
+
hidden_dim = embed_dim * 4; self.mlp = nn.Sequential(nn.Linear(embed_dim, hidden_dim), nn.GELU(), nn.Linear(hidden_dim, embed_dim))
|
126 |
+
def forward(self, x): res = x; x = self.norm1(x); x = x.transpose(0, 1); attn, _ = self.attn(x, x, x); x = attn.transpose(0, 1); x = res + x; return x + self.mlp(self.norm2(x))
|
127 |
+
class VisionTransformer(nn.Module): def __init__(self, config): super().__init__(); self.img_size = config.get("img_size", 592)
|
128 |
+
self.patch_size = config.get("patch_size", 16); self.embed_dim = config.get("hidden_size", 768); depth = config.get("depth", 12); num_heads = config.get("num_heads", 12)
|
129 |
+
num_patches = (self.img_size // self.patch_size) ** 2; self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim)); self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, self.embed_dim))
|
130 |
+
self.patch_embed = nn.Conv2d(3, self.embed_dim, kernel_size=self.patch_size, stride=self.patch_size); self.blocks = nn.ModuleList([TransformerBlock(self.embed_dim, num_heads) for _ in range(depth)]); self.norm = nn.LayerNorm(self.embed_dim)
|
131 |
+
self.register_tokens = nn.Parameter(torch.zeros(1, 4, self.embed_dim)); self._init_weights()
|
132 |
+
def _init_weights(self): nn.init.normal_(self.cls_token, std=0.02); nn.init.normal_(self.pos_embed, std=0.02)
|
133 |
+
def forward(self, x): x = self.patch_embed(x); x = x.flatten(2).transpose(1, 2); cls_tokens = self.cls_token.expand(x.shape[0], -1, -1); x = torch.cat((cls_tokens, x), dim=1); x = x + self.pos_embed
|
134 |
+
for blk in self.blocks: x = blk(x)
return self.norm(x)[:, 0]
|
135 |
+
class OpenLRM(nn.Module): def __init__(self, config): super().__init__(); self.encoder = nn.ModuleDict({"model": VisionTransformer(config)}); hidden = config.get("hidden_size", 768) if isinstance(config, dict) else config.__dict__.get("hidden_size", 768); self.linear = nn.Linear(hidden, hidden)
|
136 |
+
def forward(self, x): return self.linear(self.encoder["model"](x))
|
137 |
+
class VideoUNet(nn.Module): def __init__(self, in_ch=4, out_ch=4, features=None): super().__init__();
|
138 |
+
if features is None: features = [64, 128, 256]
|
139 |
+
self.encoder = nn.ModuleList(); self.pool = nn.MaxPool3d(2, 2); self.decoder = nn.ModuleList()
|
140 |
+
for f in features: self.encoder.append(nn.Sequential(nn.Conv3d(in_ch, f, 3, padding=1), nn.ReLU(inplace=True), nn.Conv3d(f, f, 3, padding=1), nn.ReLU(inplace=True))); in_ch = f
|
141 |
+
for f in reversed(features): self.decoder.append(nn.Sequential(nn.Conv3d(f * 2, f, 3, padding=1), nn.ReLU(inplace=True), nn.Conv3d(f, f, 3, padding=1), nn.ReLU(inplace=True)))
|
142 |
+
self.final_conv = nn.Conv3d(features[0], out_ch, 1)
|
143 |
+
def forward(self, x, t, encoder_hidden_states):
    skips = []
    for enc in self.encoder: x = enc(x); skips.append(x); x = self.pool(x)
|
144 |
+
for dec in self.decoder: skip = skips.pop(); x = F.interpolate(x, scale_factor=2, mode='trilinear', align_corners=False); x = torch.cat([x, skip], dim=1); x = dec(x)
|
145 |
+
return self.final_conv(x)
|
146 |
+
class SentimentClassifierModel(nn.Module): def __init__(self, config): super().__init__(); self.classifier = nn.Sequential(nn.Linear(768, 256), nn.ReLU(), nn.Linear(256, 2))
|
147 |
+
def forward(self, x): return self.classifier(x)
|
148 |
+
class STTModel(nn.Module): def __init__(self, config): super().__init__(); self.net = nn.Sequential(nn.Linear(768, 512), nn.ReLU(), nn.Linear(512, 768))
|
149 |
+
def forward(self, x): return self.net(x)
|
150 |
+
class TTSModel(nn.Module): def __init__(self, config): super().__init__(); self.net = nn.Sequential(nn.Linear(768, 512), nn.ReLU(), nn.Linear(512, 768))
|
151 |
+
def forward(self, x): return self.net(x)
|
152 |
+
class MusicGenModel(nn.Module): def __init__(self, config): super().__init__(); layer = nn.TransformerEncoderLayer(d_model=768, nhead=12); self.transformer = nn.TransformerEncoder(layer, num_layers=12); self.linear = nn.Linear(768, 768)
|
153 |
+
def forward(self, x): return self.linear(self.transformer(x))
|
154 |
+
class SimpleTextEncoder(nn.Module): def __init__(self, vocab_size=10000, embed_dim=768, max_length=77): super().__init__(); self.embedding = nn.Embedding(vocab_size, embed_dim); self.max_length = max_length
|
155 |
+
def forward(self, text_tokens): return self.embedding(text_tokens)
|
156 |
+
class DiffusionScheduler: def __init__(self, steps): self.steps = steps; self.betas = torch.linspace(0.1, 0.001, steps=steps).to(device); self.alphas = 1 - self.betas; self.alpha_bars = torch.cumprod(self.alphas, dim=0)
|
157 |
+
def step(self, noise, t, sample): alpha_bar = self.alpha_bars[t]; alpha_bar_prev = self.alpha_bars[t-1] if t > 0 else torch.tensor(1.0, device=sample.device)
|
158 |
+
x0 = (sample - torch.sqrt(1 - alpha_bar) * noise) / torch.sqrt(alpha_bar); new_sample = torch.sqrt(alpha_bar_prev) * x0 + torch.sqrt(1 - alpha_bar_prev) * noise; return new_sample
|
159 |
+
class VideoOutput: def __init__(self, frames): self.frames = [img_as_ubyte(frame) for frame in frames[0]]
|
160 |
+
class VideoPipeline(nn.Module): def __init__(self, unet, vae, text_encoder, vocab): super().__init__(); self.unet = unet; self.vae = vae; self.text_encoder = text_encoder; self.vocab = vocab
|
161 |
+
def forward(self, prompt: str, steps: int = 25, num_frames: int = 24): token_ids = simple_tokenizer(prompt, self.vocab); text_emb = self.text_encoder(token_ids)
|
162 |
+
latent = torch.randn((1, 4, num_frames, 64, 64), device=device).half(); sched = DiffusionScheduler(steps)
|
163 |
+
for t in range(steps): noise = self.unet(latent, t, text_emb); latent = sched.step(noise, t, latent)
|
164 |
+
frames = self.vae.decode(latent / 0.18215); frames = frames.clamp(0, 1).float().cpu().permute(0, 2, 3, 4, 1).numpy(); return VideoOutput(frames)
|
165 |
+
class XTTSModelClass(nn.Module):
|
166 |
+
def __init__(self, config):
|
167 |
+
super().__init__()
|
168 |
+
self.xtts = XTTSModel(config, num_speakers=1024, num_languages=25) # Adjust num_speakers, num_languages as needed
|
169 |
+
|
170 |
+
def forward(self, text_tokens, text_lengths, speaker_ids, language_ids, voice_samples, voice_sample_lengths):
|
171 |
+
return self.xtts.forward(text_tokens, text_lengths, speaker_ids, language_ids, voice_samples, voice_sample_lengths)
|
172 |
+
|
173 |
+
def inference(self, text, language_id, speaker_id, voice_sample, temperature=0.7, length_penalty=1.0):
|
174 |
+
return self.xtts.inference(text, language_id, speaker_id, voice_sample, temperature, length_penalty)
|
175 |
+
|
176 |
+
|
177 |
+
def initialize_gpt2_model(folder, files): download_files(folder, files); config = GPT2Config(); model = GPT2LMHeadModel(config).to(device)
|
178 |
+
sd = torch.load(os.path.join(folder, sanitize_filename("gpt2-pytorch_model.bin")), map_location=device); load_state_dict_safe(model, sd); model.eval(); enc = read_json(os.path.join(folder, sanitize_filename("encoder.json"))); return model, enc
|
179 |
+
def initialize_translation_model(folder, files): download_files(folder, files); config = MBartConfig.from_dict(read_json(os.path.join(folder, "config.json")))
|
180 |
+
model = MBartForConditionalGeneration(config).to(device); sd = torch.load(os.path.join(folder, "pytorch_model.bin"), map_location=device); load_state_dict_safe(model, sd); model.eval()
|
181 |
+
vp = os.path.join(folder, "vocab.json");
|
182 |
+
if os.path.exists(vp): vocab = read_json(vp); model.tokenizer = lambda txt: [vocab.get(t, 0) for t in txt.split()]
|
183 |
+
else: model.tokenizer = lambda txt: txt
|
184 |
+
model.config.lang_code_to_id = {'en_XX': 0, 'es_XX': 1}; return model
|
185 |
+
def initialize_codegen_model(folder, files): download_files(folder, files); config = CodeGenConfig.from_dict(read_json(os.path.join(folder, "config.json")))
|
186 |
+
model = CodeGenForCausalLM(config).to(device); sd = torch.load(os.path.join(folder, "pytorch_model.bin"), map_location=device); load_state_dict_safe(model, sd); model.eval()
|
187 |
+
tok = get_codegen_tokenizer(os.path.join(folder, "vocab.json"), os.path.join(folder, "merges.txt")); vocab = read_json(os.path.join(folder, "vocab.json")); idx2w = {v: k for k, v in vocab.items()}
|
188 |
+
model.tokenizer = tok; return model, tok, vocab, idx2w, vocab
|
189 |
+
def initialize_summarization_model(folder, files): download_files(folder, files); config = BartConfig.from_dict(read_json(os.path.join(folder, "config.json")))
|
190 |
+
model = BartForConditionalGeneration(config).to(device); sd = torch.load(os.path.join(folder, "pytorch_model.bin"), map_location=device); load_state_dict_safe(model, sd); model.eval()
|
191 |
+
vp = os.path.join(folder, "vocab.json");
|
192 |
+
if os.path.exists(vp): vocab_json = read_json(vp); vocab = set(vocab_json.keys()); return model, vocab, vocab_json, {v: k for k, v in vocab_json.items()}
|
193 |
+
return model, None, None, None
|
194 |
+
def initialize_imagegen_model(folder, files): download_files(folder, files); config = AutoencoderKLConfig.from_dict(read_json(os.path.join(folder, "config.json")))
|
195 |
+
vae = AutoencoderKL(config).to(device); sd = torch.load(os.path.join(folder, "diffusion_pytorch_model.bin"), map_location=device); load_state_dict_safe(vae, sd); vae.eval(); return vae
|
196 |
+
def initialize_image_to_3d_model(folder, files): download_files(folder, files); config = OpenLRMConfig.from_dict(read_json(os.path.join(folder, "config.json")))
|
197 |
+
model3d = OpenLRM(config).to(device); sd = torch.load(os.path.join(folder, "pytorch_model.bin"), map_location=device); load_state_dict_safe(model3d, sd); model3d.eval(); return model3d
|
198 |
+
def initialize_text_to_video_model(folder, files): download_files(folder, files); unet_cfg = read_json(os.path.join(folder, "config.json"))
|
199 |
+
unet_cfg = filter_kwargs(VideoUNet, unet_cfg); unet = VideoUNet(**unet_cfg).half().to(device)
|
200 |
+
sd_unet = torch.load(os.path.join(folder, "diffusion_pytorch_model.fp16.bin"), map_location=device); load_state_dict_safe(unet, sd_unet); unet.eval()
|
201 |
+
vae_cfg = read_json(os.path.join(folder, "config.json")); vae_cfg = filter_kwargs(AutoencoderKL, vae_cfg); vae = AutoencoderKL(vae_cfg).half().to(device)
|
202 |
+
sd_vae = torch.load(os.path.join(folder, "diffusion_pytorch_model.bin"), map_location=device); load_state_dict_safe(vae, sd_vae); vae.eval()
|
203 |
+
vp = os.path.join(folder, "vocab.json"); text_vocab = read_json(vp) if os.path.exists(vp) else {}; te_path = os.path.join(folder, "text_encoder.bin")
|
204 |
+
if os.path.exists(te_path): text_encoder = SimpleTextEncoder(vocab_size=(max(text_vocab.values())+1) if text_vocab else 10000, embed_dim=768, max_length=77).to(device); sd_te = torch.load(te_path, map_location=device); load_state_dict_safe(text_encoder, sd_te)
|
205 |
+
else: text_encoder = SimpleTextEncoder(vocab_size=(max(text_vocab.values())+1) if text_vocab else 10000, embed_dim=768, max_length=77).to(device)
|
206 |
+
text_encoder.eval(); return VideoPipeline(unet, vae, text_encoder, text_vocab)
|
207 |
+
def initialize_sentiment_model(folder, files): download_files(folder, files); config = BartConfig.from_dict(read_json(os.path.join(folder, "config.json")))
|
208 |
+
model = SentimentClassifierModel(config).to(device); sd = torch.load(os.path.join(folder, "pytorch_model.bin"), map_location=device); load_state_dict_safe(model, sd); model.eval()
|
209 |
+
vp = os.path.join(folder, "vocab.json");
|
210 |
+
if os.path.exists(vp): read_json(vp)
return model
|
211 |
+
def initialize_stt_model(folder, files): download_files(folder, files); config = BartConfig.from_dict(read_json(os.path.join(folder, "config.json")))
|
212 |
+
model = STTModel(config).to(device); sd = torch.load(os.path.join(folder, "pytorch_model.bin"), map_location=device); load_state_dict_safe(model, sd); model.eval()
|
213 |
+
vp = os.path.join(folder, "vocab.json");
|
214 |
+
if os.path.exists(vp): read_json(vp)
return model
|
215 |
+
def initialize_tts_model(folder, files): download_files(folder, files); config = BartConfig.from_dict(read_json(os.path.join(folder, "config.json")))
|
216 |
+
model = TTSModel(config).to(device); sd = torch.load(os.path.join(folder, "pytorch_model.bin"), map_location=device); load_state_dict_safe(model, sd); model.eval()
|
217 |
+
vp = os.path.join(folder, "vocab.json");
|
218 |
+
if os.path.exists(vp): read_json(vp)
return model
|
219 |
+
def initialize_musicgen_model(folder, files): download_files(folder, files); config = MusicGenConfig.from_dict(read_json(os.path.join(folder, "config.json")))
|
220 |
+
model = MusicGenModel(config).to(device); sd = torch.load(os.path.join(folder, "pytorch_model.bin"), map_location=device); load_state_dict_safe(model, sd); model.eval(); return model
|
221 |
+
|
222 |
+
def initialize_xtts_model(folder, files):
|
223 |
+
download_files(folder, files)
|
224 |
+
config_xtts = XTTSConfig.from_dict(read_json(os.path.join(folder, "config.json")))
|
225 |
+
model = XTTSModelClass(config_xtts).to(device)
|
226 |
+
checkpoint = torch.load(os.path.join(folder, "model.pth"), map_location=torch.device(device))
|
227 |
+
model.load_state_dict(checkpoint["model"], strict=False)
|
228 |
+
model.eval()
|
229 |
+
return model.xtts
|
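As a sanity check on the scheduler above, a minimal sketch that replays the DiffusionScheduler.step update on a dummy latent; the 10 steps, CPU tensors and 1x4x8x8 shape are illustrative assumptions, not values taken from the pipeline.
import torch

# Hedged sketch: same x0-prediction update as DiffusionScheduler.step, on toy data.
steps = 10
betas = torch.linspace(0.1, 0.001, steps=steps)
alphas = 1 - betas
alpha_bars = torch.cumprod(alphas, dim=0)

sample = torch.randn(1, 4, 8, 8)                # toy latent
for t in range(steps):
    noise = torch.randn_like(sample)            # stands in for the UNet's noise prediction
    alpha_bar = alpha_bars[t]
    alpha_bar_prev = alpha_bars[t - 1] if t > 0 else torch.tensor(1.0)
    x0 = (sample - torch.sqrt(1 - alpha_bar) * noise) / torch.sqrt(alpha_bar)
    sample = torch.sqrt(alpha_bar_prev) * x0 + torch.sqrt(1 - alpha_bar_prev) * noise
print(sample.shape)                             # torch.Size([1, 4, 8, 8])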
models.py
CHANGED
@@ -2,9 +2,11 @@ import torch
|
|
2 |
import torch.nn as nn
|
3 |
import torch.nn.functional as F
|
4 |
import math
|
5 |
-
import
|
6 |
-
from
|
7 |
-
from
|
|
|
|
|
8 |
|
9 |
class SentimentClassifierModel(nn.Module):
|
10 |
def __init__(self, config):
|
@@ -13,15 +15,12 @@ class SentimentClassifierModel(nn.Module):
|
|
13 |
self.embedding = nn.Embedding(config.vocab_size, config.d_model)
|
14 |
self.lstm = nn.LSTM(config.d_model, config.d_model, batch_first=True, bidirectional=True)
|
15 |
self.fc = nn.Linear(config.d_model * 2, 3)
|
16 |
-
|
17 |
def forward(self, input_ids):
|
18 |
embedded = self.embedding(input_ids)
|
19 |
packed_embedded = nn.utils.rnn.pack_padded_sequence(embedded, lengths=[input_ids.size(1)]*input_ids.size(0), batch_first=True, enforce_sorted=False)
|
20 |
packed_output, _ = self.lstm(packed_embedded)
|
21 |
output, _ = nn.utils.rnn.pad_packed_sequence(packed_output, batch_first=True)
|
22 |
-
pooled = output[:, -1, :]
|
23 |
-
logits = self.fc(pooled)
|
24 |
-
return logits
|
25 |
|
26 |
class STTModel(nn.Module):
|
27 |
def __init__(self, config):
|
@@ -35,17 +34,11 @@ class STTModel(nn.Module):
|
|
35 |
self.pool2 = nn.MaxPool1d(kernel_size=2, stride=2)
|
36 |
self.lstm = nn.LSTM(32 * (config.max_position_embeddings // 8), 128, batch_first=True, bidirectional=True)
|
37 |
self.fc = nn.Linear(128 * 2, config.vocab_size)
|
38 |
-
|
39 |
def forward(self, audio_data):
|
40 |
x = self.pool1(self.relu1(self.conv1(audio_data.unsqueeze(1))))
|
41 |
-
x = self.pool2(self.relu2(self.conv2(x)))
|
42 |
-
|
43 |
-
|
44 |
-
packed_output = nn.utils.rnn.pack_padded_sequence(x, lengths=[x.size(1)]*x.size(0), batch_first=True, enforce_sorted=False)
|
45 |
-
packed_output, _ = self.lstm(packed_output)
|
46 |
-
output, _ = nn.utils.rnn.pad_packed_sequence(packed_output, batch_first=True)
|
47 |
-
logits = self.fc(output)
|
48 |
-
return logits
|
49 |
|
50 |
class TTSModel(nn.Module):
|
51 |
def __init__(self, config):
|
@@ -55,15 +48,9 @@ class TTSModel(nn.Module):
|
|
55 |
self.lstm = nn.LSTM(config.d_model, config.d_model, batch_first=True, bidirectional=True)
|
56 |
self.fc = nn.Linear(config.d_model * 2, 1)
|
57 |
self.sigmoid = nn.Sigmoid()
|
58 |
-
|
59 |
def forward(self, input_ids):
|
60 |
-
embedded = self.embedding(input_ids)
|
61 |
-
packed_embedded = nn.utils.rnn.
|
62 |
-
packed_output, _ = self.lstm(packed_embedded)
|
63 |
-
output, _ = nn.utils.rnn.pad_packed_sequence(packed_output, batch_first=True)
|
64 |
-
logits = self.fc(output)
|
65 |
-
audio = self.sigmoid(logits)
|
66 |
-
return audio
|
67 |
|
68 |
class MusicGenModel(nn.Module):
|
69 |
def __init__(self, config: MusicGenConfig):
|
@@ -72,23 +59,47 @@ class MusicGenModel(nn.Module):
|
|
72 |
self.embedding = nn.Embedding(config.vocab_size, config.hidden_size)
|
73 |
self.transformer_layers = nn.ModuleList([CodeGenBlock(config) for _ in range(config.num_hidden_layers)])
|
74 |
self.fc_out = nn.Linear(config.hidden_size, config.vocab_size)
|
75 |
-
|
76 |
def forward(self, input_ids):
|
77 |
-
embedded_tokens = self.embedding(input_ids)
|
78 |
-
hidden_states =
|
79 |
-
|
80 |
-
hidden_states = layer(hidden_states)
|
81 |
-
logits = self.fc_out(hidden_states)
|
82 |
-
return logits
|
83 |
-
|
84 |
def sample(self, attributes, sample_rate, duration):
|
85 |
-
input_tokens = torch.randint(0, self.config.vocab_size, (1, 1), dtype=torch.long).to(device)
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
|
94 |
-
|
|
|
|
|
|
|
|
|
|
2 |
import torch.nn as nn
|
3 |
import torch.nn.functional as F
|
4 |
import math
|
5 |
+
from configs import MusicGenConfig
|
6 |
+
from extensions import CodeGenBlock
|
7 |
+
from TTS.tts.layers.xtts.transformer import XTransformerEncoder, XTransformerDecoder
|
8 |
+
from TTS.tts.layers.xtts.flow import VitsFlowModules
|
9 |
+
from TTS.tts.layers.xtts.tokenizer import VoiceBPE
from tqdm import tqdm
|
10 |
|
11 |
class SentimentClassifierModel(nn.Module):
|
12 |
def __init__(self, config):
|
|
|
15 |
self.embedding = nn.Embedding(config.vocab_size, config.d_model)
|
16 |
self.lstm = nn.LSTM(config.d_model, config.d_model, batch_first=True, bidirectional=True)
|
17 |
self.fc = nn.Linear(config.d_model * 2, 3)
|
|
|
18 |
def forward(self, input_ids):
|
19 |
embedded = self.embedding(input_ids)
|
20 |
packed_embedded = nn.utils.rnn.pack_padded_sequence(embedded, lengths=[input_ids.size(1)]*input_ids.size(0), batch_first=True, enforce_sorted=False)
|
21 |
packed_output, _ = self.lstm(packed_embedded)
|
22 |
output, _ = nn.utils.rnn.pad_packed_sequence(packed_output, batch_first=True)
|
23 |
+
pooled = output[:, -1, :]; logits = self.fc(pooled); return logits
|
|
|
|
|
24 |
|
25 |
class STTModel(nn.Module):
|
26 |
def __init__(self, config):
|
|
|
34 |
self.pool2 = nn.MaxPool1d(kernel_size=2, stride=2)
|
35 |
self.lstm = nn.LSTM(32 * (config.max_position_embeddings // 8), 128, batch_first=True, bidirectional=True)
|
36 |
self.fc = nn.Linear(128 * 2, config.vocab_size)
|
|
|
37 |
def forward(self, audio_data):
|
38 |
x = self.pool1(self.relu1(self.conv1(audio_data.unsqueeze(1))))
|
39 |
+
x = self.pool2(self.relu2(self.conv2(x))); x = x.transpose(1, 2).contiguous(); x = x.view(x.size(0), -1, x.size(2))
|
40 |
+
packed_output = nn.utils.rnn.pack_padded_sequence(x, lengths=[x.size(1)]*x.size(0), batch_first=True, enforce_sorted=False); packed_output, _ = self.lstm(packed_output)
|
41 |
+
output, _ = nn.utils.rnn.pad_packed_sequence(packed_output, batch_first=True); logits = self.fc(output); return logits
|
|
|
|
|
|
|
|
|
|
|
42 |
|
43 |
class TTSModel(nn.Module):
|
44 |
def __init__(self, config):
|
|
|
48 |
self.lstm = nn.LSTM(config.d_model, config.d_model, batch_first=True, bidirectional=True)
|
49 |
self.fc = nn.Linear(config.d_model * 2, 1)
|
50 |
self.sigmoid = nn.Sigmoid()
|
|
|
51 |
def forward(self, input_ids):
|
52 |
+
embedded = self.embedding(input_ids); packed_embedded = nn.utils.rnn.pack_padded_sequence(embedded, lengths=[input_ids.size(1)]*input_ids.size(0), batch_first=True, enforce_sorted=False)
|
53 |
+
packed_output, _ = self.lstm(packed_embedded); output, _ = nn.utils.rnn.pad_packed_sequence(packed_output, batch_first=True); logits = self.fc(output); audio = self.sigmoid(logits); return audio
|
|
|
|
|
|
|
|
|
|
|
54 |
|
55 |
class MusicGenModel(nn.Module):
|
56 |
def __init__(self, config: MusicGenConfig):
|
|
|
59 |
self.embedding = nn.Embedding(config.vocab_size, config.hidden_size)
|
60 |
self.transformer_layers = nn.ModuleList([CodeGenBlock(config) for _ in range(config.num_hidden_layers)])
|
61 |
self.fc_out = nn.Linear(config.hidden_size, config.vocab_size)
|
|
|
62 |
def forward(self, input_ids):
|
63 |
+
embedded_tokens = self.embedding(input_ids); hidden_states = embedded_tokens
|
64 |
+
for layer in self.transformer_layers: hidden_states = layer(hidden_states)
|
65 |
+
logits = self.fc_out(hidden_states); return logits
|
|
|
|
|
|
|
|
|
66 |
def sample(self, attributes, sample_rate, duration):
|
67 |
+
input_tokens = torch.randint(0, self.config.vocab_size, (1, 1), dtype=torch.long).to(device); audio_output = []; num_steps = int(duration * sample_rate / 1024)
|
68 |
+
for _ in tqdm(range(num_steps), desc="Generating music"): logits = self.forward(input_tokens); predicted_token = torch.argmax(logits[:, -1, :], dim=-1, keepdim=True); audio_output.append(predicted_token.cpu()); input_tokens = torch.cat((input_tokens, predicted_token), dim=1)
|
69 |
+
audio_output = torch.cat(audio_output, dim=1).float(); return audio_output
|
70 |
+
|
71 |
+
class XTTSModelClass(nn.Module):
|
72 |
+
def __init__(self, config):
|
73 |
+
super().__init__()
|
74 |
+
self.xtts = XTTSModel(config, num_speakers=1024, num_languages=25)
|
75 |
+
|
76 |
+
def forward(self, text_tokens, text_lengths, speaker_ids, language_ids, voice_samples, voice_sample_lengths):
|
77 |
+
return self.xtts.forward(text_tokens, text_lengths, speaker_ids, language_ids, voice_samples, voice_sample_lengths)
|
78 |
+
|
79 |
+
def inference(self, text, language_id, speaker_id, voice_sample, temperature=0.7, length_penalty=1.0):
|
80 |
+
return self.xtts.inference(text, language_id, speaker_id, voice_sample, temperature, length_penalty)
|
81 |
+
|
82 |
+
class XTTSModel(nn.Module):
|
83 |
+
def __init__(self, config, num_speakers, num_languages):
|
84 |
+
super().__init__()
|
85 |
+
self.config = config
|
86 |
+
self.num_speakers = num_speakers
|
87 |
+
self.num_languages = num_languages
|
88 |
+
self.encoder = XTransformerEncoder(**config.encoder_config)
|
89 |
+
self.decoder = XTransformerDecoder(**config.decoder_config)
|
90 |
+
self.flow_modules = VitsFlowModules(**config.flow_config)
|
91 |
+
self.voice_tokenizer = VoiceBPE(vocab_path=config.voice_tokenizer_config.vocab_path, vocab_size=config.voice_tokenizer_config.vocab_size)
|
92 |
+
self.language_embedding = nn.Embedding(num_languages, config.embedding_dim)
|
93 |
+
self.speaker_embedding = nn.Embedding(num_speakers, config.embedding_dim)
|
94 |
+
self.text_embedding = nn.Embedding(config.num_chars, config.embedding_dim)
|
95 |
+
|
96 |
+
def forward(self, text_tokens, text_lengths, speaker_ids, language_ids, voice_samples, voice_sample_lengths):
|
97 |
+
lang_embed = self.language_embedding(language_ids); spk_embed = self.speaker_embedding(speaker_ids); text_embed = self.text_embedding(text_tokens)
|
98 |
+
encoder_outputs, _ = self.encoder(text_embed, text_lengths, lang_embed + spk_embed); mel_outputs, _ = self.decoder(encoder_outputs, lang_embed + spk_embed, voice_samples); return mel_outputs, None
|
99 |
+
|
100 |
+
def inference(self, text, language_id, speaker_id, voice_sample, temperature=0.7, length_penalty=1.0):
|
101 |
+
language_ids = torch.tensor([language_id], dtype=torch.long).to(device); speaker_ids = torch.tensor([speaker_id], dtype=torch.long).to(device)
|
102 |
+
text_tokens = self.voice_tokenizer.text_to_ids(text).to(device); text_lengths = torch.tensor([text_tokens.shape[0]], dtype=torch.long).to(device); voice_sample_lengths = torch.tensor([voice_sample.shape[0]], dtype=torch.long).to(device)
|
103 |
+
lang_embed = self.language_embedding(language_ids); spk_embed = self.speaker_embedding(speaker_ids); text_embed = self.text_embedding(text_tokens)
|
104 |
+
encoder_outputs, _ = self.encoder(text_embed, text_lengths, lang_embed + spk_embed); mel_outputs, _ = self.decoder.inference(encoder_outputs, lang_embed + spk_embed, voice_sample, temperature=temperature, length_penalty=length_penalty)
|
105 |
+
return mel_outputs
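A self-contained, hedged sketch of the pack_padded_sequence / pad_packed_sequence round trip that SentimentClassifierModel, STTModel and TTSModel rely on; the embedding and LSTM sizes below are illustrative, not the configured ones.
import torch
import torch.nn as nn

embed = nn.Embedding(100, 16)
lstm = nn.LSTM(16, 16, batch_first=True, bidirectional=True)

input_ids = torch.randint(0, 100, (2, 5))            # batch of 2 sequences of length 5
lengths = [input_ids.size(1)] * input_ids.size(0)    # same full-length scheme the models use
packed = nn.utils.rnn.pack_padded_sequence(embed(input_ids), lengths=lengths, batch_first=True, enforce_sorted=False)
packed_out, _ = lstm(packed)
out, _ = nn.utils.rnn.pad_packed_sequence(packed_out, batch_first=True)
print(out.shape)                                     # torch.Size([2, 5, 32]) because the LSTM is bidirectional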
|
musicgen_api.py
CHANGED
@@ -1,34 +1,15 @@
|
|
1 |
from flask import jsonify, send_file, request
|
2 |
from main import *
|
3 |
-
import torch
|
4 |
-
import soundfile as sf
|
5 |
-
import numpy as np
|
6 |
-
import io
|
7 |
|
8 |
def generate_music(prompt, output_path="output_music.wav"):
|
9 |
-
if musicgen_model is None:
|
10 |
-
|
|
|
|
|
11 |
|
12 |
-
|
13 |
-
sample_rate = 32000
|
14 |
-
duration = 10
|
15 |
-
audio_values = musicgen_model.sample(
|
16 |
-
attributes=attributes,
|
17 |
-
sample_rate=sample_rate,
|
18 |
-
duration=duration,
|
19 |
-
)
|
20 |
-
output_audio = audio_values.cpu().numpy().squeeze()
|
21 |
-
sf.write(output_path, output_audio, sample_rate)
|
22 |
-
return output_path
|
23 |
-
|
24 |
-
def musicgen_api():
|
25 |
-
data = request.get_json()
|
26 |
-
prompt = data.get('prompt')
|
27 |
-
if not prompt:
|
28 |
-
return jsonify({"error": "Prompt is required"}), 400
|
29 |
output_file = generate_music(prompt)
|
30 |
-
if output_file
|
31 |
-
|
32 |
-
|
33 |
-
audio_content = f.read()
|
34 |
-
return send_file(io.BytesIO(audio_content), mimetype="audio/wav", as_attachment=True, download_name="output.wav")
|
|
|
1 |
from flask import jsonify, send_file, request
|
2 |
from main import *
|
3 |
+
import torch, soundfile as sf, numpy as np, io, base64
|
|
|
|
|
|
|
4 |
|
5 |
def generate_music(prompt, output_path="output_music.wav"):
|
6 |
+
if musicgen_model is None: return {"error": "Music generation model not initialized."}
|
7 |
+
attributes = [prompt]; sample_rate = 32000; duration = 8
|
8 |
+
audio_values = musicgen_model.sample(attributes=attributes, sample_rate=sample_rate, duration=duration); output_audio = audio_values.cpu().numpy().squeeze()
|
9 |
+
sf.write(output_path, output_audio, sample_rate); return output_path
|
10 |
|
11 |
+
def musicgen_api(prompt):
|
|
|
|
|
|
|
12 |
output_file = generate_music(prompt)
|
13 |
+
if isinstance(output_file, dict) and "error" in output_file: return {"error": output_file["error"]}
|
14 |
+
with open(output_file, 'rb') as f: audio_content = f.read()
|
15 |
+
audio_base64 = base64.b64encode(audio_content).decode('utf-8'); os.remove(output_file); return {"audio_base64": audio_base64, "mimetype": "audio/wav"}
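For reference, a hedged sketch of how a caller might consume the dictionary returned by musicgen_api above; the prompt and output filename are made up for illustration.
import base64

result = musicgen_api("calm piano melody")           # illustrative prompt
if "error" not in result:
    with open("generated.wav", "wb") as f:           # illustrative filename
        f.write(base64.b64decode(result["audio_base64"]))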
|
|
|
|
sadtalker_api.py
CHANGED
@@ -1,9 +1,4 @@
|
|
1 |
-
import os
|
2 |
-
import tempfile
|
3 |
-
import uuid
|
4 |
-
import asyncio
|
5 |
-
import shutil
|
6 |
-
import requests
|
7 |
from urllib.parse import urlparse
|
8 |
from fastapi import FastAPI, UploadFile, File, HTTPException, Form, WebSocket
|
9 |
from fastapi.responses import JSONResponse
|
@@ -19,186 +14,24 @@ from text_generation import *
|
|
19 |
router = APIRouter()
|
20 |
|
21 |
@router.post("/sadtalker")
|
22 |
-
async def create_video(
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
batch_size: int = Form(1),
|
31 |
-
size: int = Form(256),
|
32 |
-
pose_style: int = Form(0),
|
33 |
-
exp_scale: float = Form(1.0),
|
34 |
-
use_ref_video: bool = Form(False),
|
35 |
-
ref_video: str = Form(None),
|
36 |
-
ref_video_file: UploadFile = File(None),
|
37 |
-
ref_info: str = Form(None),
|
38 |
-
use_idle_mode: bool = Form(False),
|
39 |
-
length_of_audio: int = Form(0),
|
40 |
-
use_blink: bool = Form(True),
|
41 |
-
checkpoint_dir: str = Form('checkpoints'),
|
42 |
-
config_dir: str = Form('src/config'),
|
43 |
-
old_version: bool = Form(False),
|
44 |
-
tts_text: str = Form(None),
|
45 |
-
tts_lang: str = Form('en'),
|
46 |
-
):
|
47 |
-
if source_image_file and source_image:
|
48 |
-
raise HTTPException(status_code=400, detail="source_image and source_image_file cannot be both not None")
|
49 |
-
if driven_audio and driven_audio_file:
|
50 |
-
raise HTTPException(status_code=400, detail="driven_audio and driven_audio_file cannot be both not None")
|
51 |
-
if ref_video and ref_video_file:
|
52 |
-
raise HTTPException(status_code=400, detail="ref_video and ref_video_file cannot be both not None")
|
53 |
-
tmp_source_image = None
|
54 |
-
if source_image_file:
|
55 |
-
tmp_source_image = tempfile.NamedTemporaryFile(suffix=os.path.splitext(source_image_file.filename)[1], delete=False)
|
56 |
-
content = await source_image_file.read()
|
57 |
-
tmp_source_image.write(content)
|
58 |
-
source_image_path = tmp_source_image.name
|
59 |
-
elif source_image:
|
60 |
-
if urlparse(source_image).scheme in ["http", "https"]:
|
61 |
-
response = requests.get(source_image, stream=True)
|
62 |
-
response.raise_for_status()
|
63 |
-
with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as tmp_source_image:
|
64 |
-
for chunk in response.iter_content(chunk_size=8192):
|
65 |
-
tmp_source_image.write(chunk)
|
66 |
-
source_image_path = tmp_source_image.name
|
67 |
-
else:
|
68 |
-
source_image_path = source_image
|
69 |
-
else:
|
70 |
-
raise HTTPException(status_code=400, detail="source_image not provided")
|
71 |
-
tmp_driven_audio = None
|
72 |
-
if driven_audio_file:
|
73 |
-
tmp_driven_audio = tempfile.NamedTemporaryFile(suffix=os.path.splitext(driven_audio_file.filename)[1], delete=False)
|
74 |
-
content = await driven_audio_file.read()
|
75 |
-
tmp_driven_audio.write(content)
|
76 |
-
driven_audio_path = tmp_driven_audio.name
|
77 |
-
elif driven_audio:
|
78 |
-
if urlparse(driven_audio).scheme in ["http", "https"]:
|
79 |
-
response = requests.get(driven_audio, stream=True)
|
80 |
-
response.raise_for_status()
|
81 |
-
with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as tmp_driven_audio:
|
82 |
-
for chunk in response.iter_content(chunk_size=8192):
|
83 |
-
tmp_driven_audio.write(chunk)
|
84 |
-
driven_audio_path = tmp_driven_audio.name
|
85 |
-
else:
|
86 |
-
driven_audio_path = driven_audio
|
87 |
-
else:
|
88 |
-
driven_audio_path = None
|
89 |
-
tmp_ref_video = None
|
90 |
-
if ref_video_file:
|
91 |
-
tmp_ref_video = tempfile.NamedTemporaryFile(suffix=os.path.splitext(ref_video_file.filename)[1], delete=False)
|
92 |
-
content = await ref_video_file.read()
|
93 |
-
tmp_ref_video.write(content)
|
94 |
-
ref_video_path = tmp_ref_video.name
|
95 |
-
elif ref_video:
|
96 |
-
if urlparse(ref_video).scheme in ["http", "https"]:
|
97 |
-
response = requests.get(ref_video, stream=True)
|
98 |
-
response.raise_for_status()
|
99 |
-
with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as tmp_ref_video:
|
100 |
-
for chunk in response.iter_content(chunk_size=8192):
|
101 |
-
tmp_ref_video.write(chunk)
|
102 |
-
ref_video_path = tmp_ref_video.name
|
103 |
-
else:
|
104 |
-
ref_video_path = ref_video
|
105 |
-
else:
|
106 |
-
ref_video_path=None
|
107 |
-
try:
|
108 |
-
loop = asyncio.get_running_loop()
|
109 |
-
output_path = await loop.run_in_executor(None, sadtalker_instance.test,
|
110 |
-
source_image_path,
|
111 |
-
driven_audio_path,
|
112 |
-
preprocess,
|
113 |
-
still_mode,
|
114 |
-
use_enhancer,
|
115 |
-
batch_size,
|
116 |
-
size,
|
117 |
-
pose_style,
|
118 |
-
exp_scale,
|
119 |
-
use_ref_video,
|
120 |
-
ref_video_path,
|
121 |
-
ref_info,
|
122 |
-
use_idle_mode,
|
123 |
-
length_of_audio,
|
124 |
-
use_blink,
|
125 |
-
'./results/',
|
126 |
-
tts_text=tts_text,
|
127 |
-
tts_lang=tts_lang,
|
128 |
-
)
|
129 |
-
return {"video_url": output_path}
|
130 |
-
except Exception as e:
|
131 |
-
raise HTTPException(status_code=500, detail=str(e))
|
132 |
-
finally:
|
133 |
-
if tmp_source_image:
|
134 |
-
os.remove(tmp_source_image.name)
|
135 |
-
if tmp_driven_audio:
|
136 |
-
os.remove(tmp_driven_audio.name)
|
137 |
-
if tmp_ref_video:
|
138 |
-
os.remove(tmp_ref_video.name)
|
139 |
|
140 |
-
@router.websocket("/ws")
|
141 |
-
async def websocket_endpoint(websocket: WebSocket):
|
142 |
-
await websocket.accept()
|
143 |
-
tts_model = TTSTalker()
|
144 |
try:
|
145 |
-
|
146 |
-
|
147 |
-
|
148 |
-
|
149 |
-
|
150 |
-
audio_path = await asyncio.get_running_loop().run_in_executor(None, tts_model.test, text)
|
151 |
-
elif audio_base64:
|
152 |
-
try:
|
153 |
-
audio_bytes = base64.b64decode(audio_base64)
|
154 |
-
tmp_audio_file = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
|
155 |
-
tmp_audio_file.write(audio_bytes)
|
156 |
-
audio_path = tmp_audio_file.name
|
157 |
-
transcription_text_file = speech_to_text_func(tmp_audio_file.name)
|
158 |
-
with open(transcription_text_file, 'r') as f:
|
159 |
-
transcription_text = f.read()
|
160 |
-
response_stream = perform_reasoning_stream(transcription_text, 0.7, 40, 0.0, 1.2)
|
161 |
-
response_text = ""
|
162 |
-
for chunk in response_stream:
|
163 |
-
if chunk == "<END_STREAM>":
|
164 |
-
break
|
165 |
-
response_text += chunk
|
166 |
-
audio_path = await asyncio.get_running_loop().run_in_executor(None, tts_model.test, response_text)
|
167 |
|
168 |
-
|
169 |
-
|
170 |
-
continue
|
171 |
-
finally:
|
172 |
-
if 'tmp_audio_file' in locals() and tmp_audio_file:
|
173 |
-
os.remove(tmp_audio_file.name)
|
174 |
-
else:
|
175 |
-
continue
|
176 |
-
source_image_path = './examples/source_image/cyarh.png'
|
177 |
-
ref_video_path='./examples/driven_video/vid_xdd.mp4'
|
178 |
-
loop = asyncio.get_running_loop()
|
179 |
-
output = await loop.run_in_executor(None, sadtalker_instance.test,
|
180 |
-
source_image_path,
|
181 |
-
audio_path,
|
182 |
-
'full',
|
183 |
-
True,
|
184 |
-
True,
|
185 |
-
1,
|
186 |
-
256,
|
187 |
-
0,
|
188 |
-
1,
|
189 |
-
True,
|
190 |
-
ref_video_path,
|
191 |
-
"pose+blink",
|
192 |
-
False,
|
193 |
-
0,
|
194 |
-
True,
|
195 |
-
'./results/'
|
196 |
-
)
|
197 |
-
await websocket.send_json({"video_url": output})
|
198 |
-
except Exception as e:
|
199 |
-
print(e)
|
200 |
-
await websocket.send_json({"error":str(e)})
|
201 |
|
202 |
router = APIRouter()
|
203 |
router.add_api_route("/sadtalker", create_video, methods=["POST"])
|
204 |
-
router.add_api_websocket_route("/ws", websocket_endpoint)
|
|
|
1 |
+
import os, tempfile, uuid, asyncio, shutil, requests, base64
|
|
|
|
|
|
|
|
|
|
|
2 |
from urllib.parse import urlparse
|
3 |
from fastapi import FastAPI, UploadFile, File, HTTPException, Form, WebSocket
|
4 |
from fastapi.responses import JSONResponse
|
|
|
14 |
router = APIRouter()
|
15 |
|
16 |
@router.post("/sadtalker")
|
17 |
+
async def create_video(source_image_file: UploadFile = File(...), driven_audio_file: UploadFile = File(...)):
|
18 |
+
if not source_image_file: raise HTTPException(status_code=400, detail="Source image file is required")
|
19 |
+
if not driven_audio_file: raise HTTPException(status_code=400, detail="Driven audio file is required")
|
20 |
+
|
21 |
+
temp_source_image = tempfile.NamedTemporaryFile(suffix=os.path.splitext(source_image_file.filename)[1], delete=False)
|
22 |
+
content_image = await source_image_file.read(); temp_source_image.write(content_image); source_image_path = temp_source_image.name
|
23 |
+
temp_driven_audio = tempfile.NamedTemporaryFile(suffix=os.path.splitext(driven_audio_file.filename)[1], delete=False)
|
24 |
+
content_audio = await driven_audio_file.read(); temp_driven_audio.write(content_audio); driven_audio_path = temp_driven_audio.name
|
|
|
|
|
|
|
|
|
|
|
25 |
|
|
|
|
|
|
|
|
|
26 |
try:
|
27 |
+
loop = asyncio.get_running_loop()
|
28 |
+
output_path = await loop.run_in_executor(None, sadtalker_instance.test, source_image_path, driven_audio_path)
|
29 |
+
video_base64 = None
|
30 |
+
with open(output_path, 'rb') as video_file: video_bytes = video_file.read(); video_base64 = base64.b64encode(video_bytes).decode('utf-8')
|
31 |
+
os.remove(output_path); return {"video_base64": video_base64, "mimetype": "video/mp4"}
|
|
|
|
|
|
|
32 |
|
33 |
+
except Exception as e: raise HTTPException(status_code=500, detail=str(e))
|
34 |
+
finally: os.remove(temp_source_image.name); os.remove(temp_driven_audio.name)
|
|
|
|
|
|
|
35 |
|
36 |
router = APIRouter()
|
37 |
router.add_api_route("/sadtalker", create_video, methods=["POST"])
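A hedged client-side sketch of calling the /sadtalker route above and saving the returned video; the host, port and local file names are assumptions, not values from this repo.
import base64
import requests

files = {
    "source_image_file": open("face.png", "rb"),     # assumed local files
    "driven_audio_file": open("speech.wav", "rb"),
}
resp = requests.post("http://localhost:8000/sadtalker", files=files)  # assumed host/port
resp.raise_for_status()
payload = resp.json()
with open("talking_head.mp4", "wb") as f:
    f.write(base64.b64decode(payload["video_base64"]))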
|
|
sadtalker_utils.py
CHANGED
@@ -1,820 +1,209 @@
|
|
1 |
-
import os
|
2 |
-
import
|
3 |
-
import
|
4 |
-
import
|
5 |
-
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
def
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
|
94 |
-
|
95 |
-
|
96 |
-
|
97 |
-
|
98 |
-
|
99 |
-
|
100 |
-
self.
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
-
|
112 |
-
|
113 |
-
|
114 |
-
|
115 |
-
|
116 |
-
self.
|
117 |
-
self.
|
118 |
-
|
119 |
-
|
120 |
-
|
121 |
-
|
122 |
-
|
123 |
-
|
124 |
-
|
125 |
-
|
126 |
-
|
127 |
-
|
128 |
-
|
129 |
-
|
130 |
-
def
|
131 |
-
|
132 |
-
|
133 |
-
|
134 |
-
def
|
135 |
-
|
136 |
-
|
137 |
-
|
138 |
-
|
139 |
-
|
140 |
-
def __init__(self,
|
141 |
-
super(
|
142 |
-
|
143 |
-
|
144 |
-
|
145 |
-
|
146 |
-
|
147 |
-
|
148 |
-
|
149 |
-
self.
|
150 |
-
|
151 |
-
def
|
152 |
-
|
153 |
-
|
154 |
-
|
155 |
-
|
156 |
-
|
157 |
-
def
|
158 |
-
|
159 |
-
|
160 |
-
|
161 |
-
|
162 |
-
|
163 |
-
|
164 |
-
|
165 |
-
|
166 |
-
|
167 |
-
|
168 |
-
|
169 |
-
|
170 |
-
|
171 |
-
|
172 |
-
def
|
173 |
-
|
174 |
-
|
175 |
-
|
176 |
-
self.
|
177 |
-
|
178 |
-
|
179 |
-
|
180 |
-
|
181 |
-
|
182 |
-
|
183 |
-
|
184 |
-
|
185 |
-
|
186 |
-
|
187 |
-
|
188 |
-
|
189 |
-
|
190 |
-
|
191 |
-
|
192 |
-
|
193 |
-
|
194 |
-
|
195 |
-
self.
|
196 |
-
|
197 |
-
|
198 |
-
|
199 |
-
|
200 |
-
|
201 |
-
|
202 |
-
|
203 |
-
|
204 |
-
|
205 |
-
|
206 |
-
|
207 |
-
|
208 |
-
|
209 |
-
|
210 |
-
|
211 |
-
def detect(self, image):
|
212 |
-
h, w = image.shape[:2]
|
213 |
-
return (w // 2, h // 2)
|
214 |
-
|
215 |
-
|
216 |
-
class KeypointNorm(nn.Module):
|
217 |
-
|
218 |
-
def __init__(self, device):
|
219 |
-
super(KeypointNorm, self).__init__()
|
220 |
-
self.device = device
|
221 |
-
|
222 |
-
def forward(self, kp_driving):
|
223 |
-
return kp_driving
|
224 |
-
|
225 |
-
|
226 |
-
def save_video_with_watermark(video_frames, audio_path, output_path):
|
227 |
-
H, W, _ = video_frames[0].shape
|
228 |
-
out = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*'mp4v'), 25, (W, H))
|
229 |
-
for frame in video_frames:
|
230 |
-
out.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
|
231 |
-
out.release()
|
232 |
-
|
233 |
-
|
234 |
-
def paste_pic(video_path, source_image_crop, crop_info, audio_path, output_path):
|
235 |
-
shutil.copy(video_path, output_path)
|
236 |
-
|
237 |
-
|
238 |
-
class TTSTalker:
|
239 |
-
|
240 |
-
def __init__(self):
|
241 |
-
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
242 |
-
self.tts_model = None
|
243 |
-
|
244 |
-
def load_model(self):
|
245 |
-
self.tts_model = self
|
246 |
-
|
247 |
-
def tokenizer(self, text):
|
248 |
-
return [ord(c) for c in text]
|
249 |
-
|
250 |
-
def __call__(self, input_tokens):
|
251 |
-
return torch.zeros(1, 16000, device=self.device)
|
252 |
-
|
253 |
-
def test(self, text, lang='en'):
|
254 |
-
if self.tts_model is None:
|
255 |
-
self.load_model()
|
256 |
-
output_path = os.path.join('./results', str(uuid.uuid4()) + '.wav')
|
257 |
-
os.makedirs('./results', exist_ok=True)
|
258 |
-
tokens = self.tokenizer(text)
|
259 |
-
input_tokens = torch.tensor([tokens], dtype=torch.long).to(self.device)
|
260 |
-
with torch.no_grad():
|
261 |
-
audio_output = self(input_tokens)
|
262 |
-
torchaudio.save(output_path, audio_output.cpu(), 16000)
|
263 |
-
return output_path
|
264 |
-
|
265 |
-
|
266 |
-
class SadTalker:
|
267 |
-
|
268 |
-
def __init__(self, checkpoint_path='checkpoints', config_path='src/config', size=256, preprocess='crop',
|
269 |
-
old_version=False):
|
270 |
-
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
271 |
-
self.cfg = self.get_cfg_defaults()
|
272 |
-
self.merge_from_file(os.path.join(config_path, 'sadtalker_config.yaml'))
|
273 |
-
self.cfg['MODEL']['CHECKPOINTS_DIR'] = checkpoint_path
|
274 |
-
self.cfg['MODEL']['CONFIG_DIR'] = config_path
|
275 |
-
self.cfg['MODEL']['DEVICE'] = self.device
|
276 |
-
self.cfg['INPUT_IMAGE'] = {}
|
277 |
-
self.cfg['INPUT_IMAGE']['SOURCE_IMAGE'] = 'None'
|
278 |
-
self.cfg['INPUT_IMAGE']['DRIVEN_AUDIO'] = 'None'
|
279 |
-
self.cfg['INPUT_IMAGE']['PREPROCESS'] = preprocess
|
280 |
-
self.cfg['INPUT_IMAGE']['SIZE'] = size
|
281 |
-
self.cfg['INPUT_IMAGE']['OLD_VERSION'] = old_version
|
282 |
-
|
283 |
-
for filename, url in [
|
284 |
-
(kp_file, kp_url), (aud_file, aud_url), (wav_file, wav_url), (gen_file, gen_url),
|
285 |
-
(mapx_file, mapx_url), (den_file, den_url), ('GFPGANv1.4.pth', GFPGAN_URL),
|
286 |
-
('RealESRGAN_x2plus.pth', REALESRGAN_URL)
|
287 |
-
]:
|
288 |
-
download_model(url, filename, checkpoint_path)
|
289 |
-
|
290 |
-
self.sadtalker_model = SadTalkerModel(self.cfg, device_id=[0])
|
291 |
-
|
292 |
-
def get_cfg_defaults(self):
|
293 |
-
return {
|
294 |
-
'MODEL': {
|
295 |
-
'CHECKPOINTS_DIR': '',
|
296 |
-
'CONFIG_DIR': '',
|
297 |
-
'DEVICE': self.device,
|
298 |
-
'SCALE': 64,
|
299 |
-
'NUM_VOXEL_FRAMES': 8,
|
300 |
-
'NUM_MOTION_FRAMES': 10,
|
301 |
-
'MAX_FEATURES': 256,
|
302 |
-
'DRIVEN_AUDIO_SAMPLE_RATE': 16000,
|
303 |
-
'VIDEO_FPS': 25,
|
304 |
-
'OUTPUT_VIDEO_FPS': None,
|
305 |
-
'OUTPUT_AUDIO_SAMPLE_RATE': None,
|
306 |
-
'USE_ENHANCER': False,
|
307 |
-
'ENHANCER_NAME': '',
|
308 |
-
'BG_UPSAMPLER': None,
|
309 |
-
'IS_HALF': False
|
310 |
-
},
|
311 |
-
'INPUT_IMAGE': {}
|
312 |
-
}
|
313 |
-
|
314 |
-
def merge_from_file(self, filepath):
|
315 |
-
if os.path.exists(filepath):
|
316 |
-
with open(filepath, 'r') as f:
|
317 |
-
cfg_from_file = yaml.safe_load(f)
|
318 |
-
self.cfg.update(cfg_from_file)
|
319 |
-
|
320 |
-
def test(self, source_image, driven_audio, preprocess='crop', still_mode=False, use_enhancer=False,
|
321 |
-
batch_size=1, size=256, pose_style=0, exp_scale=1.0, use_ref_video=False, ref_video=None,
|
322 |
-
ref_info=None, use_idle_mode=False, length_of_audio=0, use_blink=True, result_dir='./results/',
|
323 |
-
tts_text=None, tts_lang='en'):
|
324 |
-
self.sadtalker_model.test(source_image, driven_audio, preprocess, still_mode, use_enhancer, batch_size, size,
|
325 |
-
pose_style, exp_scale, use_ref_video, ref_video, ref_info, use_idle_mode,
|
326 |
-
length_of_audio, use_blink, result_dir, tts_text, tts_lang)
|
327 |
-
return self.sadtalker_model.save_result()
|
328 |
-
|
329 |
-
|
330 |
-
class SadTalkerModel:
|
331 |
-
|
332 |
-
def __init__(self, sadtalker_cfg, device_id=[0]):
|
333 |
-
self.cfg = sadtalker_cfg
|
334 |
-
self.device = sadtalker_cfg['MODEL'].get('DEVICE', 'cpu')
|
335 |
-
self.sadtalker = SadTalkerInnerModel(sadtalker_cfg, device_id)
|
336 |
-
self.preprocesser = self.sadtalker.preprocesser
|
337 |
-
self.kp_extractor = self.sadtalker.kp_extractor
|
338 |
-
self.generator = self.sadtalker.generator
|
339 |
-
self.mapping = self.sadtalker.mapping
|
340 |
-
self.he_estimator = self.sadtalker.he_estimator
|
341 |
-
self.audio_to_coeff = self.sadtalker.audio_to_coeff
|
342 |
-
self.animate_from_coeff = self.sadtalker.animate_from_coeff
|
343 |
-
self.face_enhancer = self.sadtalker.face_enhancer
|
344 |
-
|
345 |
-
def test(self, source_image, driven_audio, preprocess='crop', still_mode=False, use_enhancer=False,
|
346 |
-
batch_size=1, size=256, pose_style=0, exp_scale=1.0, use_ref_video=False, ref_video=None,
|
347 |
-
ref_info=None, use_idle_mode=False, length_of_audio=0, use_blink=True, result_dir='./results/',
|
348 |
-
tts_text=None, tts_lang='en', jitter_amount=10, jitter_source_image=False):
|
349 |
-
self.inner_test = SadTalkerInner(self, source_image, driven_audio, preprocess, still_mode, use_enhancer,
|
350 |
-
batch_size, size, pose_style, exp_scale, use_ref_video, ref_video, ref_info,
|
351 |
-
use_idle_mode, length_of_audio, use_blink, result_dir, tts_text, tts_lang,
|
352 |
-
jitter_amount, jitter_source_image)
|
353 |
-
return self.inner_test.test()
|
354 |
-
|
355 |
-
def save_result(self):
|
356 |
-
return self.inner_test.save_result()
|
357 |
-
|
358 |
-
|
359 |
-
class SadTalkerInner:
|
360 |
-
|
361 |
-
def __init__(self, sadtalker_model, source_image, driven_audio, preprocess, still_mode, use_enhancer,
|
362 |
-
batch_size, size, pose_style, exp_scale, use_ref_video, ref_video, ref_info, use_idle_mode,
|
363 |
-
length_of_audio, use_blink, result_dir, tts_text, tts_lang, jitter_amount, jitter_source_image):
|
364 |
-
self.sadtalker_model = sadtalker_model
|
365 |
-
self.source_image = source_image
|
366 |
-
self.driven_audio = driven_audio
|
367 |
-
self.preprocess = preprocess
|
368 |
-
self.still_mode = still_mode
|
369 |
-
self.use_enhancer = use_enhancer
|
370 |
-
self.batch_size = batch_size
|
371 |
-
self.size = size
|
372 |
-
self.pose_style = pose_style
|
373 |
-
self.exp_scale = exp_scale
|
374 |
-
self.use_ref_video = use_ref_video
|
375 |
-
self.ref_video = ref_video
|
376 |
-
self.ref_info = ref_info
|
377 |
-
self.use_idle_mode = use_idle_mode
|
378 |
-
self.length_of_audio = length_of_audio
|
379 |
-
self.use_blink = use_blink
|
380 |
-
self.result_dir = result_dir
|
381 |
-
self.tts_text = tts_text
|
382 |
-
self.tts_lang = tts_lang
|
383 |
-
self.jitter_amount = jitter_amount
|
384 |
-
self.jitter_source_image = jitter_source_image
|
385 |
-
self.device = self.sadtalker_model.device
|
386 |
-
self.output_path = None
|
387 |
-
|
388 |
-
def get_test_data(self):
|
389 |
-
proc = self.sadtalker_model.preprocesser
|
390 |
-
if self.tts_text is not None:
|
391 |
-
temp_dir = tempfile.mkdtemp()
|
392 |
-
audio_path = os.path.join(temp_dir, 'audio.wav')
|
393 |
-
tts = TTSTalker()
|
394 |
-
tts.test(self.tts_text, self.tts_lang)
|
395 |
-
self.driven_audio = audio_path
|
396 |
-
source_image_pil = Image.open(self.source_image).convert('RGB')
|
397 |
-
if self.jitter_source_image:
|
398 |
-
jitter_dx = np.random.randint(-self.jitter_amount, self.jitter_amount + 1)
|
399 |
-
jitter_dy = np.random.randint(-self.jitter_amount, self.jitter_amount + 1)
|
400 |
-
source_image_pil = Image.fromarray(
|
401 |
-
np.roll(np.roll(np.array(source_image_pil), jitter_dx, axis=1), jitter_dy, axis=0))
|
402 |
-
source_image_tensor, crop_info, cropped_image = proc.crop(source_image_pil, self.preprocess, self.size)
|
403 |
-
if self.still_mode or self.use_idle_mode:
|
404 |
-
ref_pose_coeff = proc.generate_still_pose(self.pose_style)
|
405 |
-
ref_expression_coeff = proc.generate_still_expression(self.exp_scale)
|
406 |
-
else:
|
407 |
-
ref_pose_coeff = None
|
408 |
-
ref_expression_coeff = None
|
409 |
-
audio_tensor, audio_sample_rate = proc.process_audio(self.driven_audio,
|
410 |
-
self.sadtalker_model.cfg['MODEL']['DRIVEN_AUDIO_SAMPLE_RATE'])
|
411 |
-
batch = {
|
412 |
-
'source_image': source_image_tensor.unsqueeze(0).to(self.device),
|
413 |
-
'audio': audio_tensor.unsqueeze(0).to(self.device),
|
414 |
-
'ref_pose_coeff': ref_pose_coeff,
|
415 |
-
'ref_expression_coeff': ref_expression_coeff,
|
416 |
-
'source_image_crop': cropped_image,
|
417 |
-
'crop_info': crop_info,
|
418 |
-
'use_blink': self.use_blink,
|
419 |
-
'pose_style': self.pose_style,
|
420 |
-
'exp_scale': self.exp_scale,
|
421 |
-
'ref_video': self.ref_video,
|
422 |
-
'use_ref_video': self.use_ref_video,
|
423 |
-
'ref_info': self.ref_info,
|
424 |
-
}
|
425 |
-
return batch, audio_sample_rate
|
426 |
-
|
427 |
-
def run_inference(self, batch):
|
428 |
-
kp_extractor = self.sadtalker_model.kp_extractor
|
429 |
-
generator = self.sadtalker_model.generator
|
430 |
-
mapping = self.sadtalker_model.mapping
|
431 |
-
he_estimator = self.sadtalker_model.he_estimator
|
432 |
-
audio_to_coeff = self.sadtalker_model.audio_to_coeff
|
433 |
-
animate_from_coeff = self.sadtalker_model.animate_from_coeff
|
434 |
-
face_enhancer = self.sadtalker_model.face_enhancer if self.use_enhancer else None
|
435 |
-
with torch.no_grad():
|
436 |
-
kp_source = kp_extractor(batch['source_image'])
|
437 |
-
if self.still_mode or self.use_idle_mode:
|
438 |
-
ref_pose_coeff = batch['ref_pose_coeff']
|
439 |
-
ref_expression_coeff = batch['ref_expression_coeff']
|
440 |
-
pose_coeff = audio_to_coeff.get_pose_coeff(batch['audio'], ref_pose_coeff)
|
441 |
-
expression_coeff = audio_to_coeff.get_exp_coeff(batch['audio'], ref_expression_coeff)
|
442 |
-
elif self.use_idle_mode:
|
443 |
-
ref_pose_coeff = batch['ref_pose_coeff']
|
444 |
-
ref_expression_coeff = batch['ref_expression_coeff']
|
445 |
-
pose_coeff = audio_to_coeff.get_pose_coeff(batch['audio'], ref_pose_coeff)
|
446 |
-
expression_coeff = audio_to_coeff.get_exp_coeff(batch['audio'], ref_expression_coeff)
|
447 |
-
else:
|
448 |
-
if self.use_ref_video:
|
449 |
-
kp_ref = kp_extractor(batch['source_image'])
|
450 |
-
pose_coeff = audio_to_coeff.get_pose_coeff(batch['audio'], kp_ref=kp_ref,
|
451 |
-
use_ref_info=batch['ref_info'])
|
452 |
-
else:
|
453 |
-
pose_coeff = audio_to_coeff.get_pose_coeff(batch['audio'])
|
454 |
-
expression_coeff = audio_to_coeff.get_exp_coeff(batch['audio'])
|
455 |
-
coeff = {'pose_coeff': pose_coeff, 'expression_coeff': expression_coeff}
|
456 |
-
if self.use_blink:
|
457 |
-
coeff['blink_coeff'] = audio_to_coeff.get_blink_coeff(batch['audio'])
|
458 |
-
else:
|
459 |
-
coeff['blink_coeff'] = None
|
460 |
-
kp_driving = audio_to_coeff(batch['audio'])[0]
|
461 |
-
kp_norm = animate_from_coeff.normalize_kp(kp_driving)
|
462 |
-
coeff['kp_driving'] = kp_norm
|
463 |
-
coeff['jacobian'] = [torch.eye(2).unsqueeze(0).unsqueeze(0).to(self.device)] * 4
|
464 |
-
output_video = animate_from_coeff.generate(batch['source_image'], kp_source, coeff, generator, mapping,
|
465 |
-
he_estimator, batch['audio'], batch['source_image_crop'],
|
466 |
-
face_enhancer=face_enhancer)
|
467 |
-
return output_video
|
468 |
-
|
469 |
-
def post_processing(self, output_video, audio_sample_rate, batch):
|
470 |
-
proc = self.sadtalker_model.preprocesser
|
471 |
-
base_name = os.path.splitext(os.path.basename(batch['source_image_crop']))[0]
|
472 |
-
audio_name = os.path.splitext(os.path.basename(self.driven_audio))[0]
|
473 |
-
output_video_path = os.path.join(self.result_dir, base_name + '_' + audio_name + '.mp4')
|
474 |
-
self.output_path = output_video_path
|
475 |
-
video_fps = self.sadtalker_model.cfg['MODEL']['VIDEO_FPS'] if self.sadtalker_model.cfg['MODEL'][
|
476 |
-
'OUTPUT_VIDEO_FPS'] is None else \
|
477 |
-
self.sadtalker_model.cfg['MODEL']['OUTPUT_VIDEO_FPS']
|
478 |
-
audio_output_sample_rate = self.sadtalker_model.cfg['MODEL']['DRIVEN_AUDIO_SAMPLE_RATE'] if \
|
479 |
-
self.sadtalker_model.cfg['MODEL']['OUTPUT_AUDIO_SAMPLE_RATE'] is None else \
|
480 |
-
self.sadtalker_model.cfg['MODEL']['OUTPUT_AUDIO_SAMPLE_RATE']
|
481 |
-
if self.use_enhancer:
|
482 |
-
enhanced_path = os.path.join(self.result_dir, base_name + '_' + audio_name + '_enhanced.mp4')
|
483 |
-
save_video_with_watermark(output_video, self.driven_audio, enhanced_path)
|
484 |
-
paste_pic(enhanced_path, batch['source_image_crop'], batch['crop_info'], self.driven_audio,
|
485 |
-
output_video_path)
|
486 |
-
os.remove(enhanced_path)
|
487 |
-
else:
|
488 |
-
save_video_with_watermark(output_video, self.driven_audio, output_video_path)
|
489 |
-
if self.tts_text is not None:
|
490 |
-
shutil.rmtree(os.path.dirname(self.driven_audio))
|
491 |
-
|
492 |
-
def save_result(self):
|
493 |
-
return self.output_path
|
494 |
-
|
495 |
-
def __call__(self):
|
496 |
-
return self.output_path
|
497 |
-
|
498 |
-
def test(self):
|
499 |
-
batch, audio_sample_rate = self.get_test_data()
|
500 |
-
output_video = self.run_inference(batch)
|
501 |
-
self.post_processing(output_video, audio_sample_rate, batch)
|
502 |
-
return self.save_result()
|
503 |
-
|
504 |
-
|
505 |
-
class SadTalkerInnerModel:
|
506 |
-
|
507 |
-
def __init__(self, sadtalker_cfg, device_id=[0]):
|
508 |
-
self.cfg = sadtalker_cfg
|
509 |
-
self.device = sadtalker_cfg['MODEL'].get('DEVICE', 'cpu')
|
510 |
-
self.preprocesser = Preprocesser(sadtalker_cfg, self.device)
|
511 |
-
self.kp_extractor = KeyPointExtractor(sadtalker_cfg, self.device)
|
512 |
-
self.audio_to_coeff = Audio2Coeff(sadtalker_cfg, self.device)
|
513 |
-
self.animate_from_coeff = AnimateFromCoeff(sadtalker_cfg, self.device)
|
514 |
-
self.face_enhancer = FaceEnhancer(sadtalker_cfg, self.device) if sadtalker_cfg['MODEL'][
|
515 |
-
'USE_ENHANCER'] else None
|
516 |
-
self.generator = Generator(sadtalker_cfg, self.device)
|
517 |
-
self.mapping = Mapping(sadtalker_cfg, self.device)
|
518 |
-
self.he_estimator = OcclusionAwareDenseMotion(sadtalker_cfg, self.device)
|
519 |
-
|
520 |
-
|
521 |
-
class Preprocesser:
|
522 |
-
|
523 |
-
def __init__(self, sadtalker_cfg, device):
|
524 |
-
self.cfg = sadtalker_cfg
|
525 |
-
self.device = device
|
526 |
-
self.face3d_helper = Face3DHelper(self.cfg['INPUT_IMAGE'].get('LOCAL_PCA_PATH', ''), device)
|
527 |
-
self.mouth_detector = MouthDetector()
|
528 |
-
|
529 |
-
def crop(self, source_image_pil, preprocess_type, size=256):
|
530 |
-
source_image = np.array(source_image_pil)
|
531 |
-
face_info = self.face3d_helper.run(source_image)
|
532 |
-
if face_info is None:
|
533 |
-
raise Exception("No face detected")
|
534 |
-
x_min, y_min, x_max, y_max = face_info[:4]
|
535 |
-
old_size = (x_max - x_min, y_max - y_min)
|
536 |
-
x_center = (x_max + x_min) / 2
|
537 |
-
y_center = (y_max + y_min) / 2
|
538 |
-
if preprocess_type == 'crop':
|
539 |
-
face_size = max(x_max - x_min, y_max - y_min)
|
540 |
-
x_min = int(x_center - face_size / 2)
|
541 |
-
y_min = int(y_center - face_size / 2)
|
542 |
-
x_max = int(x_center + face_size / 2)
|
543 |
-
y_max = int(y_center + face_size / 2)
|
544 |
-
else:
|
545 |
-
x_min -= int((x_max - x_min) * 0.1)
|
546 |
-
y_min -= int((y_max - y_min) * 0.1)
|
547 |
-
x_max += int((x_max - x_min) * 0.1)
|
548 |
-
y_max += int((y_max - y_min) * 0.1)
|
549 |
-
h, w = source_image.shape[:2]
|
550 |
-
x_min = max(0, x_min)
|
551 |
-
y_min = max(0, y_min)
|
552 |
-
x_max = min(w, x_max)
|
553 |
-
y_max = min(h, y_max)
|
554 |
-
cropped_image = source_image[y_min:y_max, x_min:x_max]
|
555 |
-
cropped_image_pil = Image.fromarray(cropped_image)
|
556 |
-
if size is not None and size != 0:
|
557 |
-
cropped_image_pil = cropped_image_pil.resize((size, size), Image.Resampling.LANCZOS)
|
558 |
-
source_image_tensor = self.img2tensor(cropped_image_pil)
|
559 |
-
return source_image_tensor, [[y_min, y_max], [x_min, x_max], old_size, cropped_image_pil.size], os.path.basename(
|
560 |
-
self.cfg['INPUT_IMAGE'].get('SOURCE_IMAGE', ''))
|
561 |
-
|
562 |
-
def img2tensor(self, img):
|
563 |
-
img = np.array(img).astype(np.float32) / 255.0
|
564 |
-
img = np.transpose(img, (2, 0, 1))
|
565 |
-
return torch.FloatTensor(img)
|
566 |
-
|
567 |
-
def video_to_tensor(self, video, device):
|
568 |
-
video_tensor_list = []
|
569 |
-
import torchvision.transforms as transforms
|
570 |
-
transform_func = transforms.ToTensor()
|
571 |
-
for frame in video:
|
572 |
-
frame_pil = Image.fromarray(frame)
|
573 |
-
frame_tensor = transform_func(frame_pil).unsqueeze(0).to(device)
|
574 |
-
video_tensor_list.append(frame_tensor)
|
575 |
-
video_tensor = torch.cat(video_tensor_list, dim=0)
|
576 |
-
return video_tensor
|
577 |
-
|
578 |
-
def process_audio(self, audio_path, sample_rate):
|
579 |
-
wav = load_wav_util(audio_path, sample_rate)
|
580 |
-
wav_tensor = torch.FloatTensor(wav).unsqueeze(0)
|
581 |
-
return wav_tensor, sample_rate
|
582 |
-
|
583 |
-
def generate_still_pose(self, pose_style):
|
584 |
-
ref_pose_coeff = torch.zeros((1, 64), dtype=torch.float32).to(self.device)
|
585 |
-
ref_pose_coeff[:, :3] = torch.tensor([0, 0, pose_style * 0.3], dtype=torch.float32)
|
586 |
-
return ref_pose_coeff
|
587 |
-
|
588 |
-
def generate_still_expression(self, exp_scale):
|
589 |
-
ref_expression_coeff = torch.zeros((1, 64), dtype=torch.float32).to(self.device)
|
590 |
-
ref_expression_coeff[:, :3] = torch.tensor([0, 0, exp_scale * 0.3], dtype=torch.float32)
|
591 |
-
return ref_expression_coeff
|
592 |
-
|
593 |
-
def generate_idles_pose(self, length_of_audio, pose_style):
|
594 |
-
num_frames = int(length_of_audio * self.cfg['MODEL']['VIDEO_FPS'])
|
595 |
-
ref_pose_coeff = torch.zeros((num_frames, 64), dtype=torch.float32).to(self.device)
|
596 |
-
start_pose = self.generate_still_pose(pose_style)
|
597 |
-
end_pose = self.generate_still_pose(pose_style)
|
598 |
-
for frame_idx in range(num_frames):
|
599 |
-
alpha = frame_idx / num_frames
|
600 |
-
ref_pose_coeff[frame_idx] = (1 - alpha) * start_pose + alpha * end_pose
|
601 |
-
return ref_pose_coeff
|
602 |
-
|
603 |
-
def generate_idles_expression(self, length_of_audio):
|
604 |
-
num_frames = int(length_of_audio * self.cfg['MODEL']['VIDEO_FPS'])
|
605 |
-
ref_expression_coeff = torch.zeros((num_frames, 64), dtype=torch.float32).to(self.device)
|
606 |
-
start_exp = self.generate_still_expression(1.0)
|
607 |
-
end_exp = self.generate_still_expression(1.0)
|
608 |
-
for frame_idx in range(num_frames):
|
609 |
-
alpha = frame_idx / num_frames
|
610 |
-
ref_expression_coeff[frame_idx] = (1 - alpha) * start_exp + alpha * end_exp
|
611 |
-
return ref_expression_coeff
|
612 |
-
|
613 |
-
|
614 |
-
class KeyPointExtractor(nn.Module):
|
615 |
-
|
616 |
-
def __init__(self, sadtalker_cfg, device):
|
617 |
-
super(KeyPointExtractor, self).__init__()
|
618 |
-
self.kp_extractor = OcclusionAwareKPDetector(kp_channels=sadtalker_cfg['MODEL']['NUM_MOTION_FRAMES'],
|
619 |
-
num_kp=10,
|
620 |
-
num_dilation_blocks=2,
|
621 |
-
dropout_rate=0.1).to(device)
|
622 |
-
checkpoint_path = os.path.join(sadtalker_cfg['MODEL']['CHECKPOINTS_DIR'], 'kp_detector.safetensors')
|
623 |
-
load_state_dict_robust(self.kp_extractor, checkpoint_path, device, model_name='kp_detector')
|
624 |
-
|
625 |
-
def forward(self, x):
|
626 |
-
kp = self.kp_extractor(x)
|
627 |
-
return kp
|
628 |
-
|
629 |
-
|
630 |
-
class Audio2Coeff(nn.Module):
|
631 |
-
|
632 |
-
def __init__(self, sadtalker_cfg, device):
|
633 |
-
super(Audio2Coeff, self).__init__()
|
634 |
-
self.audio_model = Wav2Vec2Model().to(device)
|
635 |
-
checkpoint_path = os.path.join(sadtalker_cfg['MODEL']['CHECKPOINTS_DIR'], 'wav2vec2.pth')
|
636 |
-
load_state_dict_robust(self.audio_model, checkpoint_path, device, model_name='wav2vec2')
|
637 |
-
self.pose_mapper = AudioCoeffsPredictor(2048, 64).to(device)
|
638 |
-
self.exp_mapper = AudioCoeffsPredictor(2048, 64).to(device)
|
639 |
-
self.blink_mapper = AudioCoeffsPredictor(2048, 1).to(device)
|
640 |
-
mapping_checkpoint = os.path.join(sadtalker_cfg['MODEL']['CHECKPOINTS_DIR'], 'auido2pose_00140-model.pth')
|
641 |
-
load_state_dict_robust(self, mapping_checkpoint, device)
|
642 |
-
|
643 |
-
def get_pose_coeff(self, audio_tensor, ref_pose_coeff=None, kp_ref=None, use_ref_info=''):
|
644 |
-
audio_embedding = self.audio_model(audio_tensor)
|
645 |
-
pose_coeff = self.pose_mapper(audio_embedding)
|
646 |
-
if ref_pose_coeff is not None:
|
647 |
-
pose_coeff = ref_pose_coeff
|
648 |
-
if kp_ref is not None and use_ref_info == 'pose':
|
649 |
-
ref_pose_6d = kp_ref['value'][:, :6]
|
650 |
-
pose_coeff[:, :6] = self.mean_std_normalize(ref_pose_6d).mean(dim=1)
|
651 |
-
return pose_coeff
|
652 |
-
|
653 |
-
def get_exp_coeff(self, audio_tensor, ref_expression_coeff=None):
|
654 |
-
audio_embedding = self.audio_model(audio_tensor)
|
655 |
-
expression_coeff = self.exp_mapper(audio_embedding)
|
656 |
-
if ref_expression_coeff is not None:
|
657 |
-
expression_coeff = ref_expression_coeff
|
658 |
-
return expression_coeff
|
659 |
-
|
660 |
-
def get_blink_coeff(self, audio_tensor):
|
661 |
-
audio_embedding = self.audio_model(audio_tensor)
|
662 |
-
blink_coeff = self.blink_mapper(audio_embedding)
|
663 |
-
return blink_coeff
|
664 |
-
|
665 |
-
def forward(self, audio):
|
666 |
-
audio_embedding = self.audio_model(audio)
|
667 |
-
pose_coeff, expression_coeff, blink_coeff = self.pose_mapper(audio_embedding), self.exp_mapper(
|
668 |
-
audio_embedding), self.blink_mapper(audio_embedding)
|
669 |
-
return pose_coeff, expression_coeff, blink_coeff
|
670 |
-
|
671 |
-
def mean_std_normalize(self, coeff):
|
672 |
-
mean = coeff.mean(dim=1, keepdim=True)
|
673 |
-
std = coeff.std(dim=1, keepdim=True)
|
674 |
-
return (coeff - mean) / std
|
675 |
-
|
676 |
-
|
677 |
-
class AnimateFromCoeff(nn.Module):
|
678 |
-
|
679 |
-
def __init__(self, sadtalker_cfg, device):
|
680 |
-
super(AnimateFromCoeff, self).__init__()
|
681 |
-
self.generator = Generator(sadtalker_cfg, device)
|
682 |
-
self.mapping = Mapping(sadtalker_cfg, device)
|
683 |
-
self.kp_norm = KeypointNorm(device=device)
|
684 |
-
self.he_estimator = OcclusionAwareDenseMotion(sadtalker_cfg, device)
|
685 |
-
|
686 |
-
def normalize_kp(self, kp_driving):
|
687 |
-
return self.kp_norm(kp_driving)
|
688 |
-
|
689 |
-
def generate(self, source_image, kp_source, coeff, generator, mapping, he_estimator, audio, source_image_crop,
|
690 |
-
face_enhancer=None):
|
691 |
-
kp_driving = coeff['kp_driving']
|
692 |
-
jacobian = coeff['jacobian']
|
693 |
-
pose_coeff = coeff['pose_coeff']
|
694 |
-
expression_coeff = coeff['expression_coeff']
|
695 |
-
blink_coeff = coeff['blink_coeff']
|
696 |
-
face_3d = mapping(expression_coeff, pose_coeff, blink_coeff) if blink_coeff is not None else mapping(expression_coeff, pose_coeff)
|
697 |
-
sparse_motion = he_estimator(kp_source, kp_driving, jacobian)
|
698 |
-
dense_motion = sparse_motion['dense_motion']
|
699 |
-
video_deocclusion = generator(source_image, dense_motion, bg_param={'mask': None, 'color': None})
|
700 |
-
video_3d = generator(source_image, dense_motion, bg_param={'mask': None, 'color': None}, face_3d_param=face_3d)
|
701 |
-
video_output = video_deocclusion['video_no_reocclusion'] + video_3d['video_3d']
|
702 |
-
if face_enhancer is not None:
|
703 |
-
video_output_enhanced = []
|
704 |
-
for frame in tqdm(video_output, 'Face enhancer running'):
|
705 |
-
pil_image = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
|
706 |
-
enhanced_image = face_enhancer.forward(np.array(pil_image))
|
707 |
-
video_output_enhanced.append(cv2.cvtColor(enhanced_image, cv2.COLOR_BGR2RGB))
|
708 |
-
video_output = video_output_enhanced
|
709 |
-
return video_output
|
710 |
-
|
711 |
-
def make_animation(self, video_array):
|
712 |
-
H, W, _ = video_array[0].shape
|
713 |
-
out = cv2.VideoWriter('./tmp.mp4', cv2.VideoWriter_fourcc(*'mp4v'), 25, (W, H))
|
714 |
-
for img in video_array:
|
715 |
-
out.write(cv2.cvtColor(img, cv2.COLOR_RGB2BGR))
|
716 |
-
out.release()
|
717 |
-
video = imageio.mimread('./tmp.mp4')
|
718 |
-
os.remove('./tmp.mp4')
|
719 |
-
return video
|
720 |
-
|
721 |
-
|
722 |
-
class Generator(nn.Module):
|
723 |
-
|
724 |
-
def __init__(self, sadtalker_cfg, device):
|
725 |
-
super(Generator, self).__init__()
|
726 |
-
self.generator = Hourglass(block_expansion=sadtalker_cfg['MODEL']['SCALE'],
|
727 |
-
num_blocks=sadtalker_cfg['MODEL']['NUM_VOXEL_FRAMES'],
|
728 |
-
max_features=sadtalker_cfg['MODEL']['MAX_FEATURES'],
|
729 |
-
num_channels=3,
|
730 |
-
kp_size=10,
|
731 |
-
num_deform_blocks=sadtalker_cfg['MODEL']['NUM_MOTION_FRAMES']).to(device)
|
732 |
-
checkpoint_path = os.path.join(sadtalker_cfg['MODEL']['CHECKPOINTS_DIR'], 'generator.pth')
|
733 |
-
load_state_dict_robust(self.generator, checkpoint_path, device, model_name='generator')
|
734 |
-
|
735 |
-
def forward(self, source_image, dense_motion, bg_param, face_3d_param=None):
|
736 |
-
video_3d = self.generator(source_image, kp_driving=dense_motion, bg_param=bg_param, face_3d_param=face_3d_param)
|
737 |
-
return {'video_3d': video_3d, 'video_no_reocclusion': video_3d}
|
738 |
-
|
739 |
-
|
740 |
-
class Mapping(nn.Module):
|
741 |
-
|
742 |
-
def __init__(self, sadtalker_cfg, device):
|
743 |
-
super(Mapping, self).__init__()
|
744 |
-
self.mapping_net = MappingNet(num_coeffs=64, num_layers=3, hidden_dim=128).to(device)
|
745 |
-
checkpoint_path = os.path.join(sadtalker_cfg['MODEL']['CHECKPOINTS_DIR'], 'mapping.pth')
|
746 |
-
load_state_dict_robust(self.mapping_net, checkpoint_path, device, model_name='mapping')
|
747 |
-
self.f_3d_mean = torch.zeros(1, 64, device=device)
|
748 |
-
|
749 |
-
def forward(self, expression_coeff, pose_coeff, blink_coeff=None):
|
750 |
-
coeff = torch.cat([expression_coeff, pose_coeff], dim=1)
|
751 |
-
face_3d = self.mapping_net(coeff) + self.f_3d_mean
|
752 |
-
if blink_coeff is not None:
|
753 |
-
face_3d[:, -1:] = blink_coeff
|
754 |
-
return face_3d
|
755 |
-
|
756 |
-
|
757 |
-
class OcclusionAwareDenseMotion(nn.Module):
|
758 |
-
|
759 |
-
def __init__(self, sadtalker_cfg, device):
|
760 |
-
super(OcclusionAwareDenseMotion, self).__init__()
|
761 |
-
self.dense_motion_network = DenseMotionNetwork(num_kp=10,
|
762 |
-
num_channels=3,
|
763 |
-
block_expansion=sadtalker_cfg['MODEL']['SCALE'],
|
764 |
-
num_blocks=sadtalker_cfg['MODEL']['NUM_MOTION_FRAMES'] - 1,
|
765 |
-
max_features=sadtalker_cfg['MODEL']['MAX_FEATURES']).to(device)
|
766 |
-
checkpoint_path = os.path.join(sadtalker_cfg['MODEL']['CHECKPOINTS_DIR'], 'dense_motion.pth')
|
767 |
-
load_state_dict_robust(self.dense_motion_network, checkpoint_path, device, model_name='dense_motion')
|
768 |
-
|
769 |
-
def forward(self, kp_source, kp_driving, jacobian):
|
770 |
-
sparse_motion = self.dense_motion_network(kp_source, kp_driving, jacobian)
|
771 |
-
return sparse_motion
|
772 |
-
|
773 |
-
|
774 |
-
class FaceEnhancer(nn.Module):
|
775 |
-
|
776 |
-
def __init__(self, sadtalker_cfg, device):
|
777 |
-
super(FaceEnhancer, self).__init__()
|
778 |
-
enhancer_name = sadtalker_cfg['MODEL']['ENHANCER_NAME']
|
779 |
-
bg_upsampler = sadtalker_cfg['MODEL']['BG_UPSAMPLER']
|
780 |
-
if enhancer_name == 'gfpgan':
|
781 |
-
from gfpgan import GFPGANer
|
782 |
-
self.face_enhancer = GFPGANer(model_path=os.path.join(sadtalker_cfg['MODEL']['CHECKPOINTS_DIR'], 'GFPGANv1.4.pth'),
|
783 |
-
upscale=1,
|
784 |
-
arch='clean',
|
785 |
-
channel_multiplier=2,
|
786 |
-
bg_upsampler=bg_upsampler)
|
787 |
-
elif enhancer_name == 'realesrgan':
|
788 |
-
from realesrgan import RealESRGANer
|
789 |
-
half = False if device == 'cpu' else sadtalker_cfg['MODEL']['IS_HALF']
|
790 |
-
self.face_enhancer = RealESRGANer(scale=2,
|
791 |
-
model_path=os.path.join(sadtalker_cfg['MODEL']['CHECKPOINTS_DIR'],
|
792 |
-
'RealESRGAN_x2plus.pth'),
|
793 |
-
tile=0,
|
794 |
-
tile_pad=10,
|
795 |
-
pre_pad=0,
|
796 |
-
half=half,
|
797 |
-
device=device)
|
798 |
-
else:
|
799 |
-
self.face_enhancer = None
|
800 |
-
|
801 |
-
def forward(self, x):
|
802 |
-
if self.face_enhancer:
|
803 |
-
return self.face_enhancer.enhance(x, outscale=1)[0]
|
804 |
-
return x
|
805 |
-
|
806 |
-
|
807 |
-
def load_models():
|
808 |
-
checkpoint_path = './checkpoints'
|
809 |
-
config_path = './src/config'
|
810 |
-
size = 256
|
811 |
-
preprocess = 'crop'
|
812 |
-
old_version = False
|
813 |
-
|
814 |
-
sadtalker_instance = SadTalker(checkpoint_path, config_path, size, preprocess, old_version)
|
815 |
-
print("SadTalker models loaded successfully!")
|
816 |
-
return sadtalker_instance
|
817 |
-
|
818 |
-
|
819 |
-
if __name__ == '__main__':
|
820 |
-
sadtalker_instance = load_models()
|
|
|
1 |
+
import os, shutil, uuid, tempfile, urllib.request, cv2, numpy as np, torch, torch.nn as nn, torch.nn.functional as F, yaml, safetensors, librosa, imageio
from tqdm import tqdm
|
2 |
+
from PIL import Image
|
3 |
+
from skimage import img_as_ubyte, transform
|
4 |
+
from scipy.io import loadmat, wavfile
|
5 |
+
|
6 |
+
class SadTalker():
|
7 |
+
def __init__(self, checkpoint_path='checkpoints', config_path='src/config', size=256, preprocess='crop', old_version=False):
|
8 |
+
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
9 |
+
self.cfg = self.get_cfg_defaults()
|
10 |
+
self.merge_from_file(os.path.join(config_path, 'sadtalker_config.yaml'))
|
11 |
+
self.cfg['MODEL']['CHECKPOINTS_DIR'] = checkpoint_path
|
12 |
+
self.cfg['MODEL']['CONFIG_DIR'] = config_path
|
13 |
+
self.cfg['MODEL']['DEVICE'] = self.device
|
14 |
+
self.cfg['INPUT_IMAGE'] = {}
|
15 |
+
self.cfg['INPUT_IMAGE']['SOURCE_IMAGE'] = 'None'
|
16 |
+
self.cfg['INPUT_IMAGE']['DRIVEN_AUDIO'] = 'None'
|
17 |
+
self.cfg['INPUT_IMAGE']['PREPROCESS'] = preprocess
|
18 |
+
self.cfg['INPUT_IMAGE']['SIZE'] = size
|
19 |
+
self.cfg['INPUT_IMAGE']['OLD_VERSION'] = old_version
|
20 |
+
for filename, url in [(kp_file, kp_url), (aud_file, aud_url), (wav_file, wav_url), (gen_file, gen_url), (mapx_file, mapx_url), (den_file, den_url), ('GFPGANv1.4.pth', GFPGAN_URL), ('RealESRGAN_x2plus.pth', REALESRGAN_URL)]: download_model(url, filename, checkpoint_path)
|
21 |
+
self.sadtalker_model = SadTalkerModel(self.cfg, device_id=[0])
|
22 |
+
|
23 |
+
def get_cfg_defaults(self):
|
24 |
+
return {'MODEL': {'CHECKPOINTS_DIR': '', 'CONFIG_DIR': '', 'DEVICE': self.device, 'SCALE': 64, 'NUM_VOXEL_FRAMES': 8, 'NUM_MOTION_FRAMES': 10, 'MAX_FEATURES': 256, 'DRIVEN_AUDIO_SAMPLE_RATE': 16000, 'VIDEO_FPS': 25, 'OUTPUT_VIDEO_FPS': None, 'OUTPUT_AUDIO_SAMPLE_RATE': None, 'USE_ENHANCER': False, 'ENHANCER_NAME': '', 'BG_UPSAMPLER': None, 'IS_HALF': False}, 'INPUT_IMAGE': {}}
|
25 |
+
|
26 |
+
def merge_from_file(self, filepath):
|
27 |
+
if os.path.exists(filepath):
|
28 |
+
with open(filepath, 'r') as f: cfg_from_file = yaml.safe_load(f); self.cfg.update(cfg_from_file)
|
29 |
+
|
30 |
+
def test(self, source_image, driven_audio, preprocess='crop', still_mode=False, use_enhancer=False, batch_size=1, size=256, pose_style=0, exp_scale=1.0, use_ref_video=False, ref_video=None, ref_info=None, use_idle_mode=False, length_of_audio=0, use_blink=True, result_dir='./results/', tts_text=None, tts_lang='en'):
|
31 |
+
self.sadtalker_model.test(source_image, driven_audio, preprocess, still_mode, use_enhancer, batch_size, size, pose_style, exp_scale, use_ref_video, ref_video, ref_info, use_idle_mode, length_of_audio, use_blink, result_dir, tts_text, tts_lang); return self.sadtalker_model.save_result()
|
32 |
+
|
33 |
+
class SadTalkerModel():
|
34 |
+
def __init__(self, sadtalker_cfg, device_id=[0]):
|
35 |
+
self.cfg = sadtalker_cfg; self.device = sadtalker_cfg['MODEL'].get('DEVICE', 'cpu')
|
36 |
+
self.sadtalker = SadTalkerInnerModel(sadtalker_cfg, device_id)
|
37 |
+
self.preprocesser = self.sadtalker.preprocesser
|
38 |
+
self.kp_extractor = self.sadtalker.kp_extractor; self.generator = self.sadtalker.generator
|
39 |
+
self.mapping = self.sadtalker.mapping; self.he_estimator = self.sadtalker.he_estimator
|
40 |
+
self.audio_to_coeff = self.sadtalker.audio_to_coeff; self.animate_from_coeff = self.sadtalker.animate_from_coeff; self.face_enhancer = self.sadtalker.face_enhancer
|
41 |
+
|
42 |
+
def test(self, source_image, driven_audio, preprocess='crop', still_mode=False, use_enhancer=False, batch_size=1, size=256, pose_style=0, exp_scale=1.0, use_ref_video=False, ref_video=None, ref_info=None, use_idle_mode=False, length_of_audio=0, use_blink=True, result_dir='./results/', tts_text=None, tts_lang='en', jitter_amount=10, jitter_source_image=False):
|
43 |
+
self.inner_test = SadTalkerInner(self, source_image, driven_audio, preprocess, still_mode, use_enhancer, batch_size, size, pose_style, exp_scale, use_ref_video, ref_video, ref_info, use_idle_mode, length_of_audio, use_blink, result_dir, tts_text, tts_lang, jitter_amount, jitter_source_image); return self.inner_test.test()
|
44 |
+
|
45 |
+
def save_result(self):
|
46 |
+
return self.inner_test.save_result()
|
47 |
+
|
48 |
+
class SadTalkerInner():
|
49 |
+
def __init__(self, sadtalker_model, source_image, driven_audio, preprocess, still_mode, use_enhancer, batch_size, size, pose_style, exp_scale, use_ref_video, ref_video, ref_info, use_idle_mode, length_of_audio, use_blink, result_dir, tts_text, tts_lang, jitter_amount, jitter_source_image):
|
50 |
+
self.sadtalker_model = sadtalker_model; self.source_image = source_image; self.driven_audio = driven_audio
|
51 |
+
self.preprocess = preprocess; self.still_mode = still_mode; self.use_enhancer = use_enhancer
|
52 |
+
self.batch_size = batch_size; self.size = size; self.pose_style = pose_style; self.exp_scale = exp_scale
|
53 |
+
self.use_ref_video = use_ref_video; self.ref_video = ref_video; self.ref_info = ref_info
|
54 |
+
self.use_idle_mode = use_idle_mode; self.length_of_audio = length_of_audio; self.use_blink = use_blink
|
55 |
+
self.result_dir = result_dir; self.tts_text = tts_text; self.tts_lang = tts_lang
|
56 |
+
self.jitter_amount = jitter_amount; self.jitter_source_image = jitter_source_image; self.device = self.sadtalker_model.device; self.output_path = None
|
57 |
+
|
58 |
+
def get_test_data(self):
|
59 |
+
proc = self.sadtalker_model.preprocesser
|
60 |
+
if self.tts_text is not None: temp_dir = tempfile.mkdtemp(); audio_path = os.path.join(temp_dir, 'audio.wav'); tts = TTSTalker(); tts.test(self.tts_text, self.tts_lang); self.driven_audio = audio_path
|
61 |
+
source_image_pil = Image.open(self.source_image).convert('RGB')
|
62 |
+
if self.jitter_source_image: jitter_dx = np.random.randint(-self.jitter_amount, self.jitter_amount + 1); jitter_dy = np.random.randint(-self.jitter_amount, self.jitter_amount + 1); source_image_pil = Image.fromarray(np.roll(np.roll(np.array(source_image_pil), jitter_dx, axis=1), jitter_dy, axis=0))
|
63 |
+
source_image_tensor, crop_info, cropped_image = proc.crop(source_image_pil, self.preprocess, self.size)
|
64 |
+
if self.still_mode or self.use_idle_mode: ref_pose_coeff = proc.generate_still_pose(self.pose_style); ref_expression_coeff = proc.generate_still_expression(self.exp_scale)
|
65 |
+
else: ref_pose_coeff = None; ref_expression_coeff = None
|
66 |
+
audio_tensor, audio_sample_rate = proc.process_audio(self.driven_audio, self.sadtalker_model.cfg['MODEL']['DRIVEN_AUDIO_SAMPLE_RATE'])
|
67 |
+
batch = {'source_image': source_image_tensor.unsqueeze(0).to(self.device), 'audio': audio_tensor.unsqueeze(0).to(self.device), 'ref_pose_coeff': ref_pose_coeff, 'ref_expression_coeff': ref_expression_coeff, 'source_image_crop': cropped_image, 'crop_info': crop_info, 'use_blink': self.use_blink, 'pose_style': self.pose_style, 'exp_scale': self.exp_scale, 'ref_video': self.ref_video, 'use_ref_video': self.use_ref_video, 'ref_info': self.ref_info}
|
68 |
+
return batch, audio_sample_rate
|
69 |
+
|
70 |
+
def run_inference(self, batch):
|
71 |
+
kp_extractor, generator, mapping, he_estimator, audio_to_coeff, animate_from_coeff, face_enhancer = self.sadtalker_model.kp_extractor, self.sadtalker_model.generator, self.sadtalker_model.mapping, self.sadtalker_model.he_estimator, self.sadtalker_model.audio_to_coeff, self.sadtalker_model.animate_from_coeff, self.sadtalker_model.face_enhancer
|
72 |
+
with torch.no_grad():
|
73 |
+
kp_source = kp_extractor(batch['source_image'])
|
74 |
+
if self.still_mode or self.use_idle_mode: pose_coeff = audio_to_coeff.get_pose_coeff(batch['audio'], batch['ref_pose_coeff']); expression_coeff = audio_to_coeff.get_exp_coeff(batch['audio'], batch['ref_expression_coeff'])
|
75 |
+
elif self.use_idle_mode: pose_coeff = audio_to_coeff.get_pose_coeff(batch['audio'], batch['ref_pose_coeff']); expression_coeff = audio_to_coeff.get_exp_coeff(batch['audio'], batch['ref_expression_coeff'])
|
76 |
+
else:
|
77 |
+
if self.use_ref_video: kp_ref = kp_extractor(batch['source_image']); pose_coeff = audio_to_coeff.get_pose_coeff(batch['audio'], kp_ref=kp_ref, use_ref_info=batch['ref_info'])
|
78 |
+
else: pose_coeff = audio_to_coeff.get_pose_coeff(batch['audio'])
|
79 |
+
expression_coeff = audio_to_coeff.get_exp_coeff(batch['audio'])
|
80 |
+
coeff = {'pose_coeff': pose_coeff, 'expression_coeff': expression_coeff}
|
81 |
+
if self.use_blink: coeff['blink_coeff'] = audio_to_coeff.get_blink_coeff(batch['audio'])
|
82 |
+
else: coeff['blink_coeff'] = None
|
83 |
+
kp_driving = audio_to_coeff(batch['audio'])[0]; kp_norm = animate_from_coeff.normalize_kp(kp_driving); coeff['kp_driving'] = kp_norm; coeff['jacobian'] = [torch.eye(2).unsqueeze(0).unsqueeze(0).to(self.device)] * 4
|
84 |
+
output_video = animate_from_coeff.generate(batch['source_image'], kp_source, coeff, generator, mapping, he_estimator, batch['audio'], batch['source_image_crop'], face_enhancer=face_enhancer)
|
85 |
+
return output_video
|
86 |
+
|
87 |
+
def post_processing(self, output_video, audio_sample_rate, batch):
|
88 |
+
proc = self.sadtalker_model.preprocesser; base_name = os.path.splitext(os.path.basename(batch['source_image_crop']))[0]; audio_name = os.path.splitext(os.path.basename(self.driven_audio))[0]
|
89 |
+
output_video_path = os.path.join(self.result_dir, base_name + '_' + audio_name + '.mp4'); self.output_path = output_video_path
|
90 |
+
video_fps = self.sadtalker_model.cfg['MODEL']['VIDEO_FPS'] if self.sadtalker_model.cfg['MODEL']['OUTPUT_VIDEO_FPS'] is None else self.sadtalker_model.cfg['MODEL']['OUTPUT_VIDEO_FPS']
|
91 |
+
audio_output_sample_rate = self.sadtalker_model.cfg['MODEL']['DRIVEN_AUDIO_SAMPLE_RATE'] if self.sadtalker_model.cfg['MODEL']['OUTPUT_AUDIO_SAMPLE_RATE'] is None else self.sadtalker_model.cfg['MODEL']['OUTPUT_AUDIO_SAMPLE_RATE']
|
92 |
+
if self.use_enhancer: enhanced_path = os.path.join(self.result_dir, base_name + '_' + audio_name + '_enhanced.mp4'); save_video_with_watermark(output_video, self.driven_audio, enhanced_path); paste_pic(enhanced_path, batch['source_image_crop'], batch['crop_info'], self.driven_audio, output_video_path); os.remove(enhanced_path)
|
93 |
+
else: save_video_with_watermark(output_video, self.driven_audio, output_video_path)
|
94 |
+
if self.tts_text is not None: shutil.rmtree(os.path.dirname(self.driven_audio))
|
95 |
+
|
96 |
+
def save_result(self):
|
97 |
+
return self.output_path
|
98 |
+
|
99 |
+
def __call__(self):
|
100 |
+
return self.output_path
|
101 |
+
|
102 |
+
def test(self):
|
103 |
+
batch, audio_sample_rate = self.get_test_data(); output_video = self.run_inference(batch); self.post_processing(output_video, audio_sample_rate, batch); return self.save_result()
|
104 |
+
|
105 |
+
class SadTalkerInnerModel():
|
106 |
+
def __init__(self, sadtalker_cfg, device_id=[0]):
|
107 |
+
self.cfg = sadtalker_cfg; self.device = sadtalker_cfg['MODEL'].get('DEVICE', 'cpu')
|
108 |
+
|
109 |
+
self.preprocesser = Preprocesser(sadtalker_cfg, self.device); self.kp_extractor = KeyPointExtractor(sadtalker_cfg, self.device)
|
110 |
+
self.audio_to_coeff = Audio2Coeff(sadtalker_cfg, self.device); self.animate_from_coeff = AnimateFromCoeff(sadtalker_cfg, self.device)
|
111 |
+
self.face_enhancer = FaceEnhancer(sadtalker_cfg, self.device) if sadtalker_cfg['MODEL']['USE_ENHANCER'] else None
|
112 |
+
self.generator = Generator(sadtalker_cfg, self.device); self.mapping = Mapping(sadtalker_cfg, self.device); self.he_estimator = OcclusionAwareDenseMotion(sadtalker_cfg, self.device)
|
113 |
+
|
114 |
+
class Preprocesser():
|
115 |
+
def __init__(self, sadtalker_cfg, device):
|
116 |
+
self.cfg = sadtalker_cfg; self.device = device
|
117 |
+
self.face3d_helper = Face3DHelper(self.cfg['INPUT_IMAGE'].get('LOCAL_PCA_PATH', ''), device); self.mouth_detector = MouthDetector()
|
118 |
+
|
119 |
+
def crop(self, source_image_pil, preprocess_type, size=256):
|
120 |
+
source_image = np.array(source_image_pil); face_info = self.face3d_helper.run(source_image)
|
121 |
+
if face_info is None: raise Exception("No face detected")
|
122 |
+
x_min, y_min, x_max, y_max = face_info[:4]; old_size = (x_max - x_min, y_max - y_min); x_center = (x_max + x_min) / 2; y_center = (y_max + y_min) / 2
|
123 |
+
if preprocess_type == 'crop': face_size = max(x_max - x_min, y_max - y_min); x_min = int(x_center - face_size / 2); y_min = int(y_center - face_size / 2); x_max = int(x_center + face_size / 2); y_max = int(y_center + face_size / 2)
|
124 |
+
else: x_min -= int((x_max - x_min) * 0.1); y_min -= int((y_max - y_min) * 0.1); x_max += int((x_max - x_min) * 0.1); y_max += int((y_max - y_min) * 0.1)
|
125 |
+
h, w = source_image.shape[:2]; x_min = max(0, x_min); y_min = max(0, y_min); x_max = min(w, x_max); y_max = min(h, y_max)
|
126 |
+
cropped_image = source_image[y_min:y_max, x_min:x_max]; cropped_image_pil = Image.fromarray(cropped_image)
|
127 |
+
if size is not None and size != 0: cropped_image_pil = cropped_image_pil.resize((size, size), Image.Resampling.LANCZOS)
|
128 |
+
source_image_tensor = self.img2tensor(cropped_image_pil); return source_image_tensor, [[y_min, y_max], [x_min, x_max], old_size, cropped_image_pil.size], os.path.basename(self.cfg['INPUT_IMAGE'].get('SOURCE_IMAGE', ''))
|
129 |
+
|
130 |
+
def img2tensor(self, img):
|
131 |
+
img = np.array(img).astype(np.float32) / 255.0; img = np.transpose(img, (2, 0, 1)); return torch.FloatTensor(img)
|
132 |
+
def video_to_tensor(self, video, device): return 0
|
133 |
+
def process_audio(self, audio_path, sample_rate): wav = load_wav_util(audio_path, sample_rate); wav_tensor = torch.FloatTensor(wav).unsqueeze(0); return wav_tensor, sample_rate
|
134 |
+
def generate_still_pose(self, pose_style): ref_pose_coeff = torch.zeros((1, 64), dtype=torch.float32).to(self.device); ref_pose_coeff[:, :3] = torch.tensor([0, 0, pose_style * 0.3], dtype=torch.float32); return ref_pose_coeff
|
135 |
+
def generate_still_expression(self, exp_scale): ref_expression_coeff = torch.zeros((1, 64), dtype=torch.float32).to(self.device); ref_expression_coeff[:, :3] = torch.tensor([0, 0, exp_scale * 0.3], dtype=torch.float32); return ref_expression_coeff
|
136 |
+
def generate_idles_pose(self, length_of_audio, pose_style): return 0
|
137 |
+
def generate_idles_expression(self, length_of_audio): return 0
|
138 |
+
|
139 |
+
class KeyPointExtractor(nn.Module):
|
140 |
+
def __init__(self, sadtalker_cfg, device):
|
141 |
+
super(KeyPointExtractor, self).__init__(); self.kp_extractor = OcclusionAwareKPDetector(kp_channels=sadtalker_cfg['MODEL']['NUM_MOTION_FRAMES'], num_kp=10, num_dilation_blocks=2, dropout_rate=0.1).to(device)
|
142 |
+
checkpoint_path = os.path.join(sadtalker_cfg['MODEL']['CHECKPOINTS_DIR'], 'kp_detector.safetensors'); load_state_dict_robust(self.kp_extractor, checkpoint_path, device, model_name='kp_detector')
|
143 |
+
def forward(self, x): kp = self.kp_extractor(x); return kp
|
144 |
+
|
145 |
+
class Audio2Coeff(nn.Module):
|
146 |
+
def __init__(self, sadtalker_cfg, device):
|
147 |
+
super(Audio2Coeff, self).__init__(); self.audio_model = Wav2Vec2Model().to(device)
|
148 |
+
checkpoint_path = os.path.join(sadtalker_cfg['MODEL']['CHECKPOINTS_DIR'], 'wav2vec2.pth'); load_state_dict_robust(self.audio_model, checkpoint_path, device, model_name='wav2vec2')
|
149 |
+
self.pose_mapper = AudioCoeffsPredictor(2048, 64).to(device); self.exp_mapper = AudioCoeffsPredictor(2048, 64).to(device); self.blink_mapper = AudioCoeffsPredictor(2048, 1).to(device)
|
150 |
+
mapping_checkpoint = os.path.join(sadtalker_cfg['MODEL']['CHECKPOINTS_DIR'], 'auido2pose_00140-model.pth'); load_state_dict_robust(self, mapping_checkpoint, device)
|
151 |
+
def get_pose_coeff(self, audio_tensor, ref_pose_coeff=None, kp_ref=None, use_ref_info=''):
    audio_embedding = self.audio_model(audio_tensor); pose_coeff = self.pose_mapper(audio_embedding)
|
152 |
+
if ref_pose_coeff is not None: pose_coeff = ref_pose_coeff
|
153 |
+
if kp_ref is not None and use_ref_info == 'pose': ref_pose_6d = kp_ref['value'][:, :6]; pose_coeff[:, :6] = self.mean_std_normalize(ref_pose_6d).mean(dim=1)
|
154 |
+
return pose_coeff
|
155 |
+
def get_exp_coeff(self, audio_tensor, ref_expression_coeff=None):
    audio_embedding = self.audio_model(audio_tensor); expression_coeff = self.exp_mapper(audio_embedding)
|
156 |
+
if ref_expression_coeff is not None: expression_coeff = ref_expression_coeff
return expression_coeff
|
157 |
+
def get_blink_coeff(self, audio_tensor): audio_embedding = self.audio_model(audio_tensor); blink_coeff = self.blink_mapper(audio_embedding); return blink_coeff
|
158 |
+
def forward(self, audio): audio_embedding = self.audio_model(audio); pose_coeff, expression_coeff, blink_coeff = self.pose_mapper(audio_embedding), self.exp_mapper(audio_embedding), self.blink_mapper(audio_embedding); return pose_coeff, expression_coeff, blink_coeff
|
159 |
+
def mean_std_normalize(self, coeff): mean = coeff.mean(dim=1, keepdim=True); std = coeff.std(dim=1, keepdim=True); return (coeff - mean) / std
|
160 |
+
|
161 |
+
class AnimateFromCoeff(nn.Module):
|
162 |
+
def __init__(self, sadtalker_cfg, device):
|
163 |
+
super(AnimateFromCoeff, self).__init__(); self.generator = Generator(sadtalker_cfg, device); self.mapping = Mapping(sadtalker_cfg, device); self.kp_norm = KeypointNorm(device=device); self.he_estimator = OcclusionAwareDenseMotion(sadtalker_cfg, device)
|
164 |
+
def normalize_kp(self, kp_driving): return self.kp_norm(kp_driving)
|
165 |
+
def generate(self, source_image, kp_source, coeff, generator, mapping, he_estimator, audio, source_image_crop, face_enhancer=None):
|
166 |
+
kp_driving, jacobian, pose_coeff, expression_coeff, blink_coeff = coeff['kp_driving'], coeff['jacobian'], coeff['pose_coeff'], coeff['expression_coeff'], coeff['blink_coeff']
|
167 |
+
face_3d = mapping(expression_coeff, pose_coeff, blink_coeff) if blink_coeff is not None else mapping(expression_coeff, pose_coeff); sparse_motion = he_estimator(kp_source, kp_driving, jacobian)
|
168 |
+
dense_motion = sparse_motion['dense_motion']; video_deocclusion = generator(source_image, dense_motion, bg_param={'mask': None, 'color': None})
|
169 |
+
video_3d = generator(source_image, dense_motion, bg_param={'mask': None, 'color': None}, face_3d_param=face_3d); video_output = video_deocclusion['video_no_reocclusion'] + video_3d['video_3d']
|
170 |
+
if face_enhancer is not None:
    video_output_enhanced = []
    for frame in tqdm(video_output, 'Face enhancer running'):
        pil_image = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
        enhanced_image = face_enhancer.forward(np.array(pil_image))
        video_output_enhanced.append(cv2.cvtColor(enhanced_image, cv2.COLOR_BGR2RGB))
    video_output = video_output_enhanced
|
171 |
+
return video_output
|
172 |
+
def make_animation(self, video_array):
    H, W, _ = video_array[0].shape
    out = cv2.VideoWriter('./tmp.mp4', cv2.VideoWriter_fourcc(*'mp4v'), 25, (W, H))
    for img in video_array: out.write(cv2.cvtColor(img, cv2.COLOR_RGB2BGR))
    out.release()
    video = imageio.mimread('./tmp.mp4')
    os.remove('./tmp.mp4')
    return video
|
173 |
+
|
174 |
+
class Generator(nn.Module):
|
175 |
+
def __init__(self, sadtalker_cfg, device):
|
176 |
+
super(Generator, self).__init__(); self.generator = Hourglass(block_expansion=sadtalker_cfg['MODEL']['SCALE'], num_blocks=sadtalker_cfg['MODEL']['NUM_VOXEL_FRAMES'], max_features=sadtalker_cfg['MODEL']['MAX_FEATURES'], num_channels=3, kp_size=10, num_deform_blocks=sadtalker_cfg['MODEL']['NUM_MOTION_FRAMES']).to(device)
|
177 |
+
checkpoint_path = os.path.join(sadtalker_cfg['MODEL']['CHECKPOINTS_DIR'], 'generator.pth'); load_state_dict_robust(self.generator, checkpoint_path, device, model_name='generator')
|
178 |
+
def forward(self, source_image, dense_motion, bg_param, face_3d_param=None): video_3d = self.generator(source_image, kp_driving=dense_motion, bg_param=bg_param, face_3d_param=face_3d_param); return {'video_3d': video_3d, 'video_no_reocclusion': video_3d}
|
179 |
+
|
180 |
+
class Mapping(nn.Module):
|
181 |
+
def __init__(self, sadtalker_cfg, device):
|
182 |
+
super(Mapping, self).__init__(); self.mapping_net = MappingNet(num_coeffs=64, num_layers=3, hidden_dim=128).to(device)
|
183 |
+
checkpoint_path = os.path.join(sadtalker_cfg['MODEL']['CHECKPOINTS_DIR'], 'mapping.pth'); load_state_dict_robust(self.mapping_net, checkpoint_path, device, model_name='mapping')
|
184 |
+
self.f_3d_mean = torch.zeros(1, 64, device=device)
|
185 |
+
def forward(self, expression_coeff, pose_coeff, blink_coeff=None):
    coeff = torch.cat([expression_coeff, pose_coeff], dim=1)
    face_3d = self.mapping_net(coeff) + self.f_3d_mean
    if blink_coeff is not None: face_3d[:, -1:] = blink_coeff
    return face_3d
|
186 |
+
|
187 |
+
class OcclusionAwareDenseMotion(nn.Module):
|
188 |
+
def __init__(self, sadtalker_cfg, device):
|
189 |
+
super(OcclusionAwareDenseMotion, self).__init__(); self.dense_motion_network = DenseMotionNetwork(num_kp=10, num_channels=3, block_expansion=sadtalker_cfg['MODEL']['SCALE'], num_blocks=sadtalker_cfg['MODEL']['NUM_MOTION_FRAMES'] - 1, max_features=sadtalker_cfg['MODEL']['MAX_FEATURES']).to(device)
|
190 |
+
checkpoint_path = os.path.join(sadtalker_cfg['MODEL']['CHECKPOINTS_DIR'], 'dense_motion.pth'); load_state_dict_robust(self.dense_motion_network, checkpoint_path, device, model_name='dense_motion')
|
191 |
+
def forward(self, kp_source, kp_driving, jacobian): sparse_motion = self.dense_motion_network(kp_source, kp_driving, jacobian); return sparse_motion
|
192 |
+
|
193 |
+
class FaceEnhancer(nn.Module):
|
194 |
+
def __init__(self, sadtalker_cfg, device):
|
195 |
+
super(FaceEnhancer, self).__init__(); enhancer_name = sadtalker_cfg['MODEL']['ENHANCER_NAME']; bg_upsampler = sadtalker_cfg['MODEL']['BG_UPSAMPLER']
|
196 |
+
if enhancer_name == 'gfpgan': from gfpgan import GFPGANer; self.face_enhancer = GFPGANer(model_path=os.path.join(sadtalker_cfg['MODEL']['CHECKPOINTS_DIR'], 'GFPGANv1.4.pth'), upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=bg_upsampler)
|
197 |
+
elif enhancer_name == 'realesrgan': from realesrgan import RealESRGANer; half = False if device == 'cpu' else sadtalker_cfg['MODEL']['IS_HALF']; self.face_enhancer = RealESRGANer(scale=2, model_path=os.path.join(sadtalker_cfg['MODEL']['CHECKPOINTS_DIR'], 'RealESRGAN_x2plus.pth'), tile=0, tile_pad=10, pre_pad=0, half=half, device=device)
|
198 |
+
else: self.face_enhancer = None
|
199 |
+
def forward(self, x): return self.face_enhancer.enhance(x, outscale=1)[0] if self.face_enhancer else x
|
200 |
+
|
201 |
+
def download_model(url, filename, checkpoint_dir):
|
202 |
+
if not os.path.exists(os.path.join(checkpoint_dir, filename)): print(f"Downloading {filename}..."); os.makedirs(checkpoint_dir, exist_ok=True); urllib.request.urlretrieve(url, os.path.join(checkpoint_dir, filename)); print(f"{filename} downloaded.")
|
203 |
+
else: print(f"{filename} already exists.")
|
204 |
+
|
205 |
+
def load_models():
|
206 |
+
checkpoint_path = './checkpoints'; config_path = './src/config'; size = 256; preprocess = 'crop'; old_version = False
|
207 |
+
sadtalker_instance = SadTalker(checkpoint_path, config_path, size, preprocess, old_version); print("SadTalker models loaded successfully!"); return sadtalker_instance
|
208 |
+
|
209 |
+
if __name__ == '__main__': sadtalker_instance = load_models()
|
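A minimal driver sketch for the rewritten wrapper above. It assumes this file is importable as sadtalker_utils, that the checkpoint/URL constants referenced in __init__ are defined, and that the input files exist; both example paths are hypothetical placeholders.

# sketch: drive the refactored SadTalker wrapper end to end (paths are placeholders)
from sadtalker_utils import load_models

def run_demo():
    sadtalker = load_models()                      # builds SadTalker with ./checkpoints and ./src/config
    video_path = sadtalker.test(
        source_image='examples/face.png',          # hypothetical portrait image
        driven_audio='examples/speech.wav',        # hypothetical driving audio clip
        still_mode=True,                           # reuse a fixed reference pose/expression
        use_enhancer=False,                        # skip GFPGAN/RealESRGAN post-processing
        result_dir='./results/')
    print('talking-head video written to', video_path)

if __name__ == '__main__':
    run_demo()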
sentiment_api.py
CHANGED
@@ -3,27 +3,14 @@ from main import *
|
|
3 |
import torch
|
4 |
|
5 |
def analyze_sentiment(text):
|
6 |
-
if sentiment_model is None:
|
7 |
-
|
8 |
-
|
9 |
-
features = [ord(c) for c in text[:10]]
|
10 |
-
while len(features) < 10:
|
11 |
-
features.append(0)
|
12 |
features_tensor = torch.tensor(features, dtype=torch.float32).unsqueeze(0).to(device)
|
13 |
-
|
14 |
-
with torch.no_grad():
|
15 |
-
output = sentiment_model(features_tensor)
|
16 |
-
sentiment_idx = torch.argmax(output, dim=1).item()
|
17 |
-
sentiment_label = "positive" if sentiment_idx == 1 else "negative"
|
18 |
-
|
19 |
return {"sentiment": sentiment_label}
|
20 |
|
21 |
-
def sentiment_api():
|
22 |
-
data = request.get_json()
|
23 |
-
text = data.get('text')
|
24 |
-
if not text:
|
25 |
-
return jsonify({"error": "Text is required"}), 400
|
26 |
output = analyze_sentiment(text)
|
27 |
-
if "error" in output:
|
28 |
-
|
29 |
-
return jsonify(output)
|
|
|
3 |
import torch
|
4 |
|
5 |
def analyze_sentiment(text):
|
6 |
+
if sentiment_model is None: return {"error": "Sentiment model not initialized."}
|
7 |
+
features = [ord(c) for c in text[:10]];
|
8 |
+
while len(features) < 10: features.append(0)
|
|
|
|
|
|
|
9 |
features_tensor = torch.tensor(features, dtype=torch.float32).unsqueeze(0).to(device)
|
10 |
+
with torch.no_grad(): output = sentiment_model(features_tensor); sentiment_idx = torch.argmax(output, dim=1).item(); sentiment_label = "positive" if sentiment_idx == 1 else "negative"
|
|
|
|
|
|
|
|
|
|
|
11 |
return {"sentiment": sentiment_label}
|
12 |
|
13 |
+
def sentiment_api(text):
|
|
|
|
|
|
|
|
|
14 |
output = analyze_sentiment(text)
|
15 |
+
if "error" in output: return {"error": output["error"]}
|
16 |
+
return output
|
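A small, self-contained sketch of the feature encoding analyze_sentiment relies on: the ordinals of the first ten characters, zero-padded to length ten. Only torch is assumed; the real sentiment_model and device still come from main.

# sketch: build the fixed-length character-code features the sentiment model consumes
import torch

def make_features(text, device="cpu"):
    feats = [ord(c) for c in text[:10]]            # ordinals of the first 10 characters
    feats += [0] * (10 - len(feats))               # right-pad with zeros to length 10
    return torch.tensor(feats, dtype=torch.float32).unsqueeze(0).to(device)  # shape (1, 10)

print(make_features("great job"))                  # 9 characters, so one trailing zero pad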
|
stt_api.py
CHANGED
@@ -1,33 +1,17 @@
|
|
1 |
-
import os
|
2 |
-
import uuid
|
3 |
from flask import jsonify, send_file, request
|
4 |
from main import *
|
5 |
-
import torch
|
6 |
-
import torchaudio
|
7 |
|
8 |
def speech_to_text_func(audio_path):
|
9 |
-
if stt_model is None:
|
10 |
-
|
11 |
-
|
12 |
-
waveform, sample_rate = torchaudio.load(audio_path)
|
13 |
-
if waveform.ndim > 1:
|
14 |
-
waveform = torch.mean(waveform, dim=0, keepdim=True)
|
15 |
waveform = waveform.to(device)
|
16 |
-
with torch.no_grad():
|
17 |
-
|
18 |
-
predicted_ids = torch.argmax(logits, dim=-1)
|
19 |
-
transcription = stt_model.tokenizer.decode(predicted_ids[0].cpu().tolist())
|
20 |
-
|
21 |
-
return {"text": transcription}
|
22 |
|
23 |
-
def stt_api():
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
temp_audio_path = f"temp_audio_{uuid.uuid4()}.wav"
|
28 |
-
audio_file.save(temp_audio_path)
|
29 |
-
output = speech_to_text_func(temp_audio_path)
|
30 |
-
os.remove(temp_audio_path)
|
31 |
-
if "error" in output:
|
32 |
-
return jsonify({"error": output["error"]}), 500
|
33 |
-
return jsonify(output)
|
|
|
1 |
+
import os, uuid
|
|
|
2 |
from flask import jsonify, send_file, request
|
3 |
from main import *
|
4 |
+
import torch, torchaudio
|
|
|
5 |
|
6 |
def speech_to_text_func(audio_path):
|
7 |
+
if stt_model is None: return {"error": "STT model not initialized."}
|
8 |
+
waveform, sample_rate = torchaudio.load(audio_path);
|
9 |
+
if waveform.ndim > 1: waveform = torch.mean(waveform, dim=0, keepdim=True)
|
|
|
|
|
|
|
10 |
waveform = waveform.to(device)
|
11 |
+
with torch.no_grad(): logits = stt_model(waveform)
|
12 |
+
predicted_ids = torch.argmax(logits, dim=-1); transcription = stt_model.tokenizer.decode(predicted_ids[0].cpu().tolist()); return {"text": transcription}
|
|
|
|
|
|
|
|
|
13 |
|
14 |
+
def stt_api(audio_filepath):
|
15 |
+
output = speech_to_text_func(audio_filepath)
|
16 |
+
if "error" in output: return {"error": output["error"]}
|
17 |
+
return output
|
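A hedged sketch of how a caller might use the path-based stt_api now that the Flask request handling has moved out of this file; the upload object (a werkzeug FileStorage) and the surrounding route are assumptions, not part of this module.

# sketch: a caller saves the upload itself, then hands stt_api a file path
import os, uuid
from stt_api import stt_api                        # assumes this module is importable as stt_api

def transcribe_upload(file_storage):               # e.g. request.files['audio'] in a Flask route
    tmp_path = f"temp_audio_{uuid.uuid4()}.wav"    # unique temp name, mirroring the old route code
    file_storage.save(tmp_path)
    try:
        return stt_api(tmp_path)                   # returns {"text": ...} or {"error": ...}
    finally:
        os.remove(tmp_path)                        # always clean up the temp file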
summarization_api.py
CHANGED
@@ -1,27 +1,14 @@
|
|
1 |
from flask import jsonify, send_file, request
|
2 |
from main import *
|
3 |
-
import torch
|
4 |
|
5 |
def summarize_text(text, output_path="output_summary.txt"):
|
6 |
-
if summarization_model is None or summarization_tokenizer is None:
|
7 |
-
return "Summarization model or tokenizer not initialized."
|
8 |
-
|
9 |
input_ids = summarization_tokenizer.encode(text, return_tensors="pt").to(device)
|
|
|
|
|
10 |
|
11 |
-
|
12 |
-
|
13 |
-
|
14 |
-
|
15 |
-
with open(output_path, "w") as file:
|
16 |
-
file.write(summary_text)
|
17 |
-
return output_path
|
18 |
-
|
19 |
-
def summarization_api():
|
20 |
-
data = request.get_json()
|
21 |
-
text = data.get('text')
|
22 |
-
if not text:
|
23 |
-
return jsonify({"error": "Text is required"}), 400
|
24 |
-
output_file = summarize_text(text)
|
25 |
-
if output_file == "Summarization model or tokenizer not initialized.":
|
26 |
-
return jsonify({"error": "Summarization failed"}), 500
|
27 |
-
return send_file(output_file, mimetype="text/plain", as_attachment=True, download_name="output_summary.txt")
|
|
|
1 |
from flask import jsonify, send_file, request
|
2 |
from main import *
|
3 |
+
import torch, io, base64
|
4 |
|
5 |
def summarize_text(text, output_path="output_summary.txt"):
|
6 |
+
if summarization_model is None or summarization_tokenizer is None: return {"error": "Summarization model or tokenizer not initialized."}
|
|
|
|
|
7 |
input_ids = summarization_tokenizer.encode(text, return_tensors="pt").to(device)
|
8 |
+
with torch.no_grad(): summary_ids = summarization_model.generate(input_ids, num_beams=4, max_length=100, early_stopping=True); summary_text = summarization_tokenizer.decode(summary_ids[0], skip_special_tokens=True)
|
9 |
+
return {"summary_text": summary_text}
|
10 |
|
11 |
+
def summarization_api(text):
|
12 |
+
output = summarize_text(text)
|
13 |
+
if "error" in output: return {"error": output["error"]}
|
14 |
+
return output
|
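A short usage sketch for the dict-returning summarizer above; the input string is a placeholder, and summarization_model / summarization_tokenizer are assumed to have been loaded by main.

# sketch: summarization_api now returns a dict instead of writing output_summary.txt
from summarization_api import summarization_api    # assumes this module is importable as summarization_api

result = summarization_api("Long article text goes here ...")
if "error" in result:
    print("summarization unavailable:", result["error"])
else:
    print(result["summary_text"])                  # beam-search summary (num_beams=4, max_length=100)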
text_generation.py
CHANGED
@@ -1,194 +1,93 @@
|
|
1 |
-
import torch
|
2 |
-
|
3 |
-
|
4 |
-
import
|
5 |
-
from
|
6 |
-
from
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
except NameError:
|
12 |
-
|
13 |
-
|
14 |
-
|
15 |
-
except NameError:
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
except NameError:
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
except NameError:
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
except NameError:
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
try:
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
def
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
else
|
94 |
-
repetition_count = 0
|
95 |
-
last_token = token_id
|
96 |
-
if repetition_count >= 10:
|
97 |
-
yield "<END_STREAM>"
|
98 |
-
return
|
99 |
-
generated.append(token_id)
|
100 |
-
token_decoded = decode_fn(token_id)
|
101 |
-
yield token_decoded
|
102 |
-
if end_token_condition(token_id):
|
103 |
-
yield "<END_STREAM>"
|
104 |
-
return
|
105 |
-
|
106 |
-
def sample_sequence(prompt, model, enc, max_length=MAX_GENERATION_LENGTH, temperature=1, top_k=0, top_p=0.0, repetition_penalty=1.0, device="cpu"):
|
107 |
-
context_tokens = enc.encode(prompt)
|
108 |
-
context_tensor = torch.tensor([context_tokens], dtype=torch.long, device=device)
|
109 |
-
return _generate_sequence(
|
110 |
-
lambda ct, past: model(ct, past_key_values=past),
|
111 |
-
context_tensor,
|
112 |
-
list(context_tokens),
|
113 |
-
lambda token: enc.decode([token]),
|
114 |
-
lambda token: token == enc.encoder[END_OF_TEXT_TOKEN],
|
115 |
-
temperature, top_k, top_p, repetition_penalty, max_length
|
116 |
-
)
|
117 |
-
|
118 |
-
def sample_sequence_codegen(prompt, model, tokenizer, max_length=MAX_GENERATION_LENGTH, temperature=1, top_k=0, top_p=0.0, repetition_penalty=1.0, device="cpu"):
|
119 |
-
context_tokens = tokenizer.encode(prompt)
|
120 |
-
context_tensor = torch.tensor([context_tokens], dtype=torch.long, device=device)
|
121 |
-
return _generate_sequence(
|
122 |
-
lambda ct, past: model(input_ids=ct, past_key_values=past, labels=None),
|
123 |
-
context_tensor,
|
124 |
-
list(context_tokens),
|
125 |
-
lambda token: tokenizer.decode([token]),
|
126 |
-
lambda token: token == 50256,
|
127 |
-
temperature, top_k, top_p, repetition_penalty, max_length
|
128 |
-
)
|
129 |
-
|
130 |
-
def summarize_text(text):
|
131 |
-
if summarization_model and summarization_tokenizer:
|
132 |
-
input_ids = summarization_tokenizer.encode(text, return_tensors="pt", truncation=True, max_length=1024).to(device)
|
133 |
-
summary_ids = summarization_model.generate(
|
134 |
-
input_ids,
|
135 |
-
max_length=150,
|
136 |
-
min_length=40,
|
137 |
-
length_penalty=2.0,
|
138 |
-
num_beams=4,
|
139 |
-
early_stopping=True
|
140 |
-
)
|
141 |
-
return summarization_tokenizer.decode(summary_ids[0], skip_special_tokens=True)
|
142 |
-
return text[:300] + "..." if len(text) > 300 else text
|
143 |
-
|
144 |
-
def perform_reasoning_stream(text_input, temperature, top_k, top_p, repetition_penalty, prev_context=""):
|
145 |
-
initial_prompt = SYSTEM_PROMPT + "\n\nUser: " + text_input + "\nAssistant:"
|
146 |
-
reasoning_prompt = prev_context if prev_context else initial_prompt
|
147 |
-
ddgs = DDGS()
|
148 |
-
search_results = [r for r in ddgs.text(text_input, max_results=MAX_XDD)]
|
149 |
-
if search_results:
|
150 |
-
reasoning_prompt += "\nWeb Search Results:\n"
|
151 |
-
for result in search_results:
|
152 |
-
reasoning_prompt += "- " + result['body'] + "\n"
|
153 |
-
reasoning_prompt += "\n"
|
154 |
-
if "code" in text_input.lower() or "program" in text_input.lower():
|
155 |
-
model_type = "code"
|
156 |
-
elif "summarize" in text_input.lower() or "summary" in text_input.lower():
|
157 |
-
model_type = "summarize"
|
158 |
-
elif model_gpt2 and enc:
|
159 |
-
model_type = "gpt2"
|
160 |
-
else:
|
161 |
-
yield "<ERROR: No se encontró un modelo adecuado>"
|
162 |
-
yield "<END_STREAM>"
|
163 |
-
return
|
164 |
-
if model_type == "summarize":
|
165 |
-
if summarization_model:
|
166 |
-
summary = summarize_text(text_input)
|
167 |
-
yield "SUMMARY_TEXT:" + summary
|
168 |
-
yield "<END_STREAM>"
|
169 |
-
return
|
170 |
-
accumulated_text = ""
|
171 |
-
current_context = reasoning_prompt
|
172 |
-
overlap = 256
|
173 |
-
while True:
|
174 |
-
if model_type == "code":
|
175 |
-
generator = sample_sequence_codegen(current_context, codegen_model, codegen_tokenizer, MAX_GENERATION_LENGTH, temperature, top_k, top_p, repetition_penalty, device)
|
176 |
-
elif model_type == "gpt2":
|
177 |
-
generator = sample_sequence(current_context, model_gpt2, enc, MAX_GENERATION_LENGTH, temperature, top_k, top_p, repetition_penalty, device)
|
178 |
-
chunk_text = ""
|
179 |
-
finished = False
|
180 |
-
for token in generator:
|
181 |
-
if token == "<END_STREAM>":
|
182 |
-
finished = True
|
183 |
-
break
|
184 |
-
chunk_text += token
|
185 |
-
if accumulated_text:
|
186 |
-
overlap_text = accumulated_text[-overlap:]
|
187 |
-
if chunk_text.startswith(overlap_text):
|
188 |
-
chunk_text = chunk_text[len(overlap_text):]
|
189 |
-
accumulated_text += chunk_text
|
190 |
-
yield chunk_text
|
191 |
-
if finished:
|
192 |
-
yield "<END_STREAM>"
|
193 |
-
break
|
194 |
-
current_context = accumulated_text[-overlap:] if len(accumulated_text) > overlap else accumulated_text
|
|
|
1 |
+
import torch, torch.nn.functional as F
|
2 |
+
from tqdm import trange
|
3 |
+
import time
|
4 |
+
from tokenxxx import *
|
5 |
+
from main import *
|
6 |
+
from duckduckgo_search import DDGS
|
7 |
+
|
8 |
+
try: END_OF_TEXT_TOKEN
|
9 |
+
except NameError: END_OF_TEXT_TOKEN = ""
|
10 |
+
try: SYSTEM_PROMPT
|
11 |
+
except NameError: SYSTEM_PROMPT = "Sistema: Proporcione respuestas ultra rápidas, coherentes, similares, precisas y con sentido, con razonamiento lógico y profundo."
|
12 |
+
try: MAX_XDD
|
13 |
+
except NameError: MAX_XDD = 5
|
14 |
+
try: codegen_model
|
15 |
+
except NameError: codegen_model = None
|
16 |
+
try: codegen_tokenizer
|
17 |
+
except NameError: codegen_tokenizer = None
|
18 |
+
try: summarization_model
|
19 |
+
except NameError: summarization_model = None
|
20 |
+
try: summarization_tokenizer
|
21 |
+
except NameError: summarization_tokenizer = None
|
22 |
+
try: model_gpt2
|
23 |
+
except NameError: model_gpt2 = None
|
24 |
+
try: enc
|
25 |
+
except NameError: enc = None
|
26 |
+
try: device
|
27 |
+
except NameError: device = "cpu"
|
28 |
+
|
29 |
+
if torch.device(device).type == "cuda": torch.backends.cudnn.benchmark = True
|
30 |
+
MAX_GENERATION_LENGTH = 512
|
31 |
+
|
32 |
+
def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
|
33 |
+
top_k = min(top_k, logits.size(-1));
|
34 |
+
if top_k > 0: indices_to_remove = logits < torch.topk(logits, top_k)[0][..., [-1]]; logits[indices_to_remove] = filter_value
|
35 |
+
if top_p > 0.0:
    sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1); cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
|
36 |
+
    sorted_indices_to_remove = cumulative_probs > top_p; sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone(); sorted_indices_to_remove[..., 0] = 0
    indices_to_remove = sorted_indices[sorted_indices_to_remove]; logits[0, indices_to_remove] = filter_value  # logits has shape (1, vocab_size) here
return logits
|
37 |
+
|
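A standalone sanity check of the same top-k / top-p steps on a toy five-token distribution; it mirrors the function above rather than importing it, so only torch is assumed.

# sketch: top-k then nucleus (top-p) filtering on a toy batch of logits
import torch, torch.nn.functional as F

logits = torch.tensor([[2.0, 1.0, 0.5, -1.0, -3.0]])        # one example, five "tokens"
top_k, top_p, filter_value = 3, 0.9, -float('Inf')

kth = torch.topk(logits, top_k)[0][..., [-1]]               # value of the k-th largest logit, shape (1, 1)
logits = logits.masked_fill(logits < kth, filter_value)     # keep only the top-k logits

sorted_logits, sorted_idx = torch.sort(logits, descending=True, dim=-1)
cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
remove = cum_probs > top_p                                  # tail of the nucleus to drop
remove[..., 1:] = remove[..., :-1].clone()                  # shift so the boundary token survives
remove[..., 0] = False                                      # always keep the best token
logits[0, sorted_idx[remove]] = filter_value                # batch of one, so index row 0

print(F.softmax(logits, dim=-1))                            # surviving tokens renormalised, the rest at 0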
38 |
+
def _generate_sequence(model_call, context_tensor, generated, decode_fn, end_token_condition, temperature, top_k, top_p, repetition_penalty, max_length):
|
39 |
+
past_key_values = None; last_token = None; repetition_count = 0
|
40 |
+
for _ in range(max_length):
|
41 |
+
try: outputs = model_call(context_tensor, past_key_values)
|
42 |
+
except Exception as e: yield "<ERROR:" + str(e) + ">"; yield "<END_STREAM>"; return
|
43 |
+
next_token_logits = outputs[0][:, -1, :] / temperature; past_key_values = outputs[1]
|
44 |
+
for token_index in set(generated): next_token_logits[0, token_index] /= repetition_penalty
|
45 |
+
filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
|
46 |
+
if temperature == 0: next_token = torch.argmax(filtered_logits, dim=-1).unsqueeze(0)
|
47 |
+
else: next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1)
|
48 |
+
token_id = next_token.tolist()[0][0]
|
49 |
+
if token_id == last_token: repetition_count += 1
|
50 |
+
else: repetition_count = 0; last_token = token_id
|
51 |
+
if repetition_count >= 10: yield "<END_STREAM>"; return
|
52 |
+
generated.append(token_id); token_decoded = decode_fn(token_id); yield token_decoded
|
53 |
+
if end_token_condition(token_id): yield "<END_STREAM>"; return
|
54 |
+
|
55 |
+
def sample_sequence(prompt, model, enc, max_length=MAX_GENERATION_LENGTH, temperature=1, top_k=0, top_p=0.0, repetition_penalty=1.0, device="cpu"):
|
56 |
+
context_tokens = enc.encode(prompt); context_tensor = torch.tensor([context_tokens], dtype=torch.long, device=device)
|
57 |
+
return _generate_sequence(lambda ct, past: model(ct, past_key_values=past), context_tensor, list(context_tokens), lambda token: enc.decode([token]), lambda token: token == enc.encoder[END_OF_TEXT_TOKEN], temperature, top_k, top_p, repetition_penalty, max_length)
|
58 |
+
|
59 |
+
def sample_sequence_codegen(prompt, model, tokenizer, max_length=MAX_GENERATION_LENGTH, temperature=1, top_k=0, top_p=0.0, repetition_penalty=1.0, device="cpu"):
|
60 |
+
context_tokens = tokenizer.encode(prompt); context_tensor = torch.tensor([context_tokens], dtype=torch.long, device=device)
|
61 |
+
return _generate_sequence(lambda ct, past: model(input_ids=ct, past_key_values=past, labels=None), context_tensor, list(context_tokens), lambda token: tokenizer.decode([token]), lambda token: token == 50256, temperature, top_k, top_p, repetition_penalty, max_length)
|
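Below is a minimal usage sketch for the streaming samplers above; model_gpt2, enc and device are assumed to be the module-level globals guarded earlier in this file, and the sampling values are only illustrative.
# Usage sketch (illustrative, not part of the uploaded file): draining the token stream.
def collect_stream(prompt):
    pieces = []
    for piece in sample_sequence(prompt, model_gpt2, enc, max_length=64, temperature=0.8, top_k=50, top_p=0.95, device=device):
        if piece == "<END_STREAM>":
            break
        if piece.startswith("<ERROR:"):
            raise RuntimeError(piece)
        pieces.append(piece)
    return "".join(pieces)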
62 |
+
|
63 |
+
def summarize_text(text):
|
64 |
+
if summarization_model and summarization_tokenizer:
|
65 |
+
input_ids = summarization_tokenizer.encode(text, return_tensors="pt", truncation=True, max_length=1024).to(device); summary_ids = summarization_model.generate(input_ids, max_length=150, min_length=40, length_penalty=2.0, num_beams=4, early_stopping=True)
|
66 |
+
return summarization_tokenizer.decode(summary_ids[0], skip_special_tokens=True)
|
67 |
+
return text[:300] + "..." if len(text) > 300 else text
|
68 |
+
|
69 |
+
def perform_reasoning_stream(text_input, temperature, top_k, top_p, repetition_penalty, prev_context=""):
|
70 |
+
initial_prompt = SYSTEM_PROMPT + "\n\nUser: " + text_input + "\nAssistant:"; reasoning_prompt = prev_context if prev_context else initial_prompt; ddgs = DDGS()
|
71 |
+
search_results = [r for r in ddgs.text(text_input, max_results=MAX_XDD)];
|
72 |
+
if search_results: reasoning_prompt += "\nWeb Search Results:\n";
|
73 |
+
for result in search_results: reasoning_prompt += "- " + result['body'] + "\n"
|
74 |
+
reasoning_prompt += "\n"
|
75 |
+
if "code" in text_input.lower() or "program" in text_input.lower(): model_type = "code"
|
76 |
+
elif "summarize" in text_input.lower() or "summary" in text_input.lower(): model_type = "summarize"
|
77 |
+
elif model_gpt2 and enc: model_type = "gpt2"
|
78 |
+
else: yield "<ERROR: No se encontró un modelo adecuado>"; yield "<END_STREAM>"; return
|
79 |
+
if model_type == "summarize":
|
80 |
+
if summarization_model: summary = summarize_text(text_input); yield "SUMMARY_TEXT:" + summary; yield "<END_STREAM>"; return
|
81 |
+
accumulated_text = ""; current_context = reasoning_prompt; overlap = 256
|
82 |
+
while True:
|
83 |
+
if model_type == "code": generator = sample_sequence_codegen(current_context, codegen_model, codegen_tokenizer, MAX_GENERATION_LENGTH, temperature, top_k, top_p, repetition_penalty, device)
|
84 |
+
elif model_type == "gpt2": generator = sample_sequence(current_context, model_gpt2, enc, MAX_GENERATION_LENGTH, temperature, top_k, top_p, repetition_penalty, device)
|
85 |
+
chunk_text = ""; finished = False
|
86 |
+
for token in generator:
|
87 |
+
if token == "<END_STREAM>": finished = True; break
|
88 |
+
chunk_text += token
|
89 |
+
if accumulated_text: overlap_text = accumulated_text[-overlap:];
|
90 |
+
if chunk_text.startswith(overlap_text): chunk_text = chunk_text[len(overlap_text):]
|
91 |
+
accumulated_text += chunk_text; yield chunk_text
|
92 |
+
if finished: yield "<END_STREAM>"; break
|
93 |
+
current_context = accumulated_text[-overlap:] if len(accumulated_text) > overlap else accumulated_text
|
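For context, a sketch of how perform_reasoning_stream is meant to be consumed end to end; it assumes DuckDuckGo search access and at least one loaded model, and the parameter values are placeholders.
# Usage sketch (illustrative): collecting a complete answer from the chunked stream.
def answer(question):
    out = []
    for chunk in perform_reasoning_stream(question, temperature=0.7, top_k=40, top_p=0.9, repetition_penalty=1.2):
        if chunk == "<END_STREAM>":
            break
        out.append(chunk)
    return "".join(out)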
text_to_video_api.py
CHANGED
@@ -1,36 +1,24 @@
|
|
1 |
-
import os
|
2 |
-
import uuid
|
3 |
from flask import jsonify, send_file, request
|
4 |
from main import *
|
5 |
-
import torch
|
6 |
-
import io
|
7 |
from skimage import img_as_ubyte
|
8 |
-
import imageio
|
9 |
|
10 |
def text_to_video_func(prompt, output_path="output_video.mp4"):
|
11 |
-
if text_to_video_model is None:
|
12 |
-
return "Text-to-Video model not initialized."
|
13 |
video_frames_list = text_to_video_model(prompt)
|
14 |
-
if video_frames_list and hasattr(video_frames_list, 'frames'):
|
15 |
-
|
16 |
-
export_to_video_pure(video_frames, output_video=output_path)
|
17 |
-
return output_path
|
18 |
-
return "Video generation failed."
|
19 |
|
20 |
def export_to_video_pure(video_frames, output_video="output_video.mp4", fps=25):
|
21 |
writer = imageio.get_writer(output_video, fps=fps)
|
22 |
-
for frame in video_frames:
|
23 |
-
writer.append_data(img_as_ubyte(frame))
|
24 |
writer.close()
|
25 |
|
26 |
-
def text_to_video_api():
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
if output_file == "Text-to-Video model not initialized." or output_file == "Video generation failed.":
|
33 |
-
return jsonify({"error": "Text to video failed"}), 500
|
34 |
-
with open(output_file, 'rb') as f:
|
35 |
-
video_content = f.read()
|
36 |
-
return send_file(io.BytesIO(video_content), mimetype='video/mp4', as_attachment=True, download_name="output_video.mp4")
|
|
|
1 |
+
import os, uuid
|
|
|
2 |
from flask import jsonify, send_file, request
|
3 |
from main import *
|
4 |
+
import torch, io
|
|
|
5 |
from skimage import img_as_ubyte
|
6 |
+
import imageio, base64
|
7 |
|
8 |
def text_to_video_func(prompt, output_path="output_video.mp4"):
|
9 |
+
if text_to_video_model is None: return {"error": "Text-to-Video model not initialized."}
|
|
|
10 |
video_frames_list = text_to_video_model(prompt)
|
11 |
+
if video_frames_list and hasattr(video_frames_list, 'frames'): export_to_video_pure(video_frames_list.frames, output_video=output_path); return output_path
|
12 |
+
return {"error": "Video generation failed."}
|
|
|
|
|
|
|
13 |
|
14 |
def export_to_video_pure(video_frames, output_video="output_video.mp4", fps=25):
|
15 |
writer = imageio.get_writer(output_video, fps=fps)
|
16 |
+
for frame in video_frames: writer.append_data(img_as_ubyte(frame))
|
|
|
17 |
writer.close()
|
18 |
|
19 |
+
def text_to_video_api(prompt):
|
20 |
+
output_data = text_to_video_func(prompt)
|
21 |
+
if "error" in output_data: return {"error": output_data["error"]}
|
22 |
+
output_file = output_data;
|
23 |
+
with open(output_file, 'rb') as f: video_content = f.read()
|
24 |
+
video_base64 = base64.b64encode(video_content).decode('utf-8'); os.remove(output_file); return {"video_base64": video_base64, "mimetype": "video/mp4"}
|
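The endpoint now returns the clip base64-encoded instead of streaming a file; a caller-side sketch (prompt and output filename are placeholders):
# Usage sketch (illustrative): writing the returned payload back to an .mp4 file.
import base64
result = text_to_video_api("a red fox running through snow")
if "error" not in result:
    with open("generated.mp4", "wb") as f:
        f.write(base64.b64decode(result["video_base64"]))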
tokenxxx.py
CHANGED
@@ -1,142 +1,72 @@
|
|
1 |
-
import json
|
2 |
-
import re
|
3 |
-
import unicodedata
|
4 |
from functools import lru_cache
|
5 |
-
import wget
|
6 |
-
import
|
7 |
-
from constants import *
|
8 |
import nltk
|
9 |
|
10 |
@lru_cache()
|
11 |
def bytes_to_unicode():
|
12 |
bs = list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1))
|
13 |
-
cs = bs[:]
|
14 |
-
n
|
15 |
-
for
|
16 |
-
if b not in bs:
|
17 |
-
bs.append(b)
|
18 |
-
cs.append(2**8 + n)
|
19 |
-
n += 1
|
20 |
-
cs = [chr(n) for n in cs]
|
21 |
-
return dict(zip(bs, cs))
|
22 |
|
23 |
def get_pairs(word):
|
24 |
-
pairs = set()
|
25 |
-
|
26 |
-
for char in word[1:]:
|
27 |
-
pairs.add((prev_char, char))
|
28 |
-
prev_char = char
|
29 |
-
return pairs
|
30 |
|
31 |
class Encoder:
|
32 |
def __init__(self, encoder, bpe_merges, errors='replace', tokenize=None):
|
33 |
-
self.encoder = encoder
|
34 |
-
self.
|
35 |
-
self.errors = errors
|
36 |
-
self.byte_encoder = bytes_to_unicode()
|
37 |
-
self.byte_decoder = {v:k for k, v in self.byte_encoder.items()}
|
38 |
self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
|
39 |
-
self.cache = {}
|
40 |
-
if tokenize is None:
|
41 |
-
|
42 |
-
self.tokenize = lambda text: re.findall(self.pat, text)
|
43 |
-
else:
|
44 |
-
self.tokenize = tokenize
|
45 |
|
46 |
def bpe(self, token):
|
47 |
-
if token in self.cache:
|
48 |
-
|
49 |
-
|
50 |
-
pairs = get_pairs(word)
|
51 |
-
if not pairs:
|
52 |
-
return token
|
53 |
while True:
|
54 |
bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
|
55 |
-
if bigram not in self.bpe_ranks:
|
56 |
-
|
57 |
-
first, second = bigram
|
58 |
-
new_word = []
|
59 |
-
i = 0
|
60 |
while i < len(word):
|
61 |
-
try:
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
new_word.append(first+second)
|
70 |
-
i += 2
|
71 |
-
else:
|
72 |
-
new_word.append(word[i])
|
73 |
-
i += 1
|
74 |
-
new_word = tuple(new_word)
|
75 |
-
word = new_word
|
76 |
-
if len(word) == 1:
|
77 |
-
break
|
78 |
-
else:
|
79 |
-
pairs = get_pairs(word)
|
80 |
-
word = ' '.join(word)
|
81 |
-
self.cache[token] = word
|
82 |
-
return word
|
83 |
|
84 |
def encode(self, text):
|
85 |
-
bpe_tokens = []
|
86 |
-
normalized_text =
|
87 |
-
normalized_text = ''.join(c for c in normalized_text if c.isascii() and c != '\t')
|
88 |
-
normalized_text = ''.join(c for c in normalized_text if not unicodedata.category(c).startswith('C'))
|
89 |
-
for token in self.tokenize(normalized_text):
|
90 |
-
token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8', errors='ignore'))
|
91 |
-
bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
|
92 |
return bpe_tokens
|
93 |
|
94 |
def decode(self, tokens):
|
95 |
-
text = ''.join([self.decoder[token] for token in tokens])
|
96 |
-
text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors='replace')
|
97 |
decoded_text = text.replace(" .", ".").replace(" ,", ",").replace(" '", "'").replace(" ?", "?").replace(" !", "!").replace(" :", ":").replace('\n', '<br>')
|
98 |
-
sentences = nltk.sent_tokenize(decoded_text)
|
99 |
-
return ' '.join(sentences).replace("<br>", "<br>\n")
|
100 |
|
101 |
def get_encoder_gpt2():
|
102 |
-
encoder_path = os.path.join(GPT2_FOLDER, ENCODER_FILE)
|
103 |
-
|
104 |
-
if not os.path.exists(
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
-
with open(encoder_path, 'r') as f:
|
112 |
-
encoder = json.load(f)
|
113 |
-
with open(vocab_path, 'r', encoding="utf-8") as f:
|
114 |
-
bpe_data = f.read()
|
115 |
-
bpe_merges = [tuple(merge_str.split()) for merge_str in bpe_data.split('\n')[1:-1]]
|
116 |
-
encoder_obj = Encoder(encoder=encoder, bpe_merges=bpe_merges)
|
117 |
-
encoder_obj.encoder[END_OF_TEXT_TOKEN] = len(encoder_obj.encoder)
|
118 |
-
encoder_obj.decoder[len(encoder_obj.decoder)] = END_OF_TEXT_TOKEN
|
119 |
-
return encoder_obj
|
120 |
|
121 |
def get_codegen_tokenizer_pure(vocab_file, merges_file):
|
122 |
-
vocab = json.load(open(vocab_file))
|
123 |
-
|
124 |
-
|
125 |
-
byte_encoder =
|
126 |
-
byte_decoder = {v: k for k, v in byte_encoder.items()}
|
127 |
-
tokenizer_regex = re.compile(r'''<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+''')
|
128 |
-
tokenize = lambda text: re.findall(tokenizer_regex, text)
|
129 |
-
encoder_obj = Encoder(
|
130 |
-
encoder=vocab,
|
131 |
-
bpe_merges=bpe_merges,
|
132 |
-
byte_encoder=byte_encoder,
|
133 |
-
byte_decoder=byte_decoder,
|
134 |
-
tokenize=tokenize
|
135 |
-
)
|
136 |
-
return encoder_obj
|
137 |
-
|
138 |
-
def codegen_tokenize(text, tokenizer):
|
139 |
-
return tokenizer.encode(text)
|
140 |
|
141 |
-
def
|
142 |
-
|
|
|
1 |
+
import json, re, unicodedata, regex  # third-party 'regex' is needed for the \p{...} classes used below
|
|
|
|
|
2 |
from functools import lru_cache
|
3 |
+
import wget, os
|
4 |
+
from constants import GPT2_FOLDER, ENCODER_FILE, VOCAB_FILE, ENCODER_URL, VOCAB_URL, END_OF_TEXT_TOKEN
|
|
|
5 |
import nltk
|
6 |
|
7 |
@lru_cache()
|
8 |
def bytes_to_unicode():
|
9 |
bs = list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1))
|
10 |
+
cs = bs[:]; n = 0
|
11 |
+
for b in range(2**8):
    if b not in bs: bs.append(b); cs.append(2**8 + n); n += 1
|
12 |
+
cs = [chr(n) for n in cs]; return dict(zip(bs, cs))
|
|
14 |
def get_pairs(word):
|
15 |
+
pairs = set(); prev_char = word[0]
|
16 |
+
for char in word[1:]: pairs.add((prev_char, char)); prev_char = char
return pairs
|
|
18 |
class Encoder:
|
19 |
def __init__(self, encoder, bpe_merges, errors='replace', tokenize=None):
|
20 |
+
self.encoder = encoder; self.decoder = {v:k for k,v in self.encoder.items()}; self.errors = errors
|
21 |
+
self.byte_encoder = bytes_to_unicode(); self.byte_decoder = {v:k for k, v in self.byte_encoder.items()}
|
|
|
|
|
|
|
22 |
self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
|
23 |
+
self.cache = {};
|
24 |
+
if tokenize is None: self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\w+| ?[^\s\w]+|\s+(?!\S)|\s+""", re.UNICODE); self.tokenize = lambda text: re.findall(self.pat, text)
|
25 |
+
else: self.tokenize = tokenize
|
|
|
|
|
|
|
26 |
|
27 |
def bpe(self, token):
|
28 |
+
if token in self.cache: return self.cache[token]
|
29 |
+
word = tuple(token); pairs = get_pairs(word)
|
30 |
+
if not pairs: return token
|
|
|
|
|
|
|
31 |
while True:
|
32 |
bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
|
33 |
+
if bigram not in self.bpe_ranks: break
|
34 |
+
first, second = bigram; new_word = []; i = 0
|
|
|
|
|
|
|
35 |
while i < len(word):
|
36 |
+
try: j = word.index(first, i); new_word.extend(word[i:j]); i = j
|
37 |
+
except ValueError: new_word.extend(word[i:]); break
|
38 |
+
if word[i] == first and i < len(word)-1 and word[i+1] == second: new_word.append(first+second); i += 2
|
39 |
+
else: new_word.append(word[i]); i += 1
|
40 |
+
new_word = tuple(new_word); word = new_word
|
41 |
+
if len(word) == 1: break
|
42 |
+
else: pairs = get_pairs(word)
|
43 |
+
word = ' '.join(word); self.cache[token] = word; return word
|
44 |
|
45 |
def encode(self, text):
|
46 |
+
bpe_tokens = []; normalized_text = unicodedata.normalize('NFKC', text); normalized_text = ''.join(c for c in normalized_text if c.isascii() and c != '\t'); normalized_text = ''.join(c for c in normalized_text if not unicodedata.category(c).startswith('C'))
|
47 |
+
for token in self.tokenize(normalized_text): token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8', errors='ignore')); bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
|
48 |
return bpe_tokens
|
49 |
|
50 |
def decode(self, tokens):
|
51 |
+
text = ''.join([self.decoder[token] for token in tokens]); text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors='replace')
|
|
|
52 |
decoded_text = text.replace(" .", ".").replace(" ,", ",").replace(" '", "'").replace(" ?", "?").replace(" !", "!").replace(" :", ":").replace('\n', '<br>')
|
53 |
+
sentences = nltk.sent_tokenize(decoded_text); return ' '.join(sentences).replace("<br>", "<br>\n")
|
|
|
54 |
|
55 |
def get_encoder_gpt2():
|
56 |
+
encoder_path = os.path.join(GPT2_FOLDER, ENCODER_FILE); vocab_path = os.path.join(GPT2_FOLDER, VOCAB_FILE)
|
57 |
+
if not os.path.exists(GPT2_FOLDER): os.makedirs(GPT2_FOLDER)
|
58 |
+
if not os.path.exists(encoder_path): wget.download(ENCODER_URL, out=encoder_path)
|
59 |
+
if not os.path.exists(vocab_path): wget.download(VOCAB_URL, out=vocab_path)
|
60 |
+
with open(encoder_path, 'r') as f: encoder = json.load(f)
|
61 |
+
with open(vocab_path, 'r', encoding="utf-8") as f: bpe_data = f.read()
|
62 |
+
bpe_merges = [tuple(merge_str.split()) for merge_str in bpe_data.split('\n')[1:-1]]; encoder_obj = Encoder(encoder=encoder, bpe_merges=bpe_merges)
|
63 |
+
encoder_obj.encoder[END_OF_TEXT_TOKEN] = len(encoder_obj.encoder); encoder_obj.decoder[len(encoder_obj.decoder)] = END_OF_TEXT_TOKEN; return encoder_obj
|
64 |
|
65 |
def get_codegen_tokenizer_pure(vocab_file, merges_file):
|
66 |
+
vocab = json.load(open(vocab_file)); merges = open(merges_file, 'r', encoding="utf-8").read().split('\n')[1:-1]; bpe_merges = [tuple(m.split()) for m in merges]
|
67 |
+
byte_encoder = bytes_to_unicode(); byte_decoder = {v: k for k, v in byte_encoder.items()}
|
68 |
+
tokenizer_regex = regex.compile(r'''<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+'''); tokenize = lambda text: tokenizer_regex.findall(text)
|
69 |
+
encoder_obj = Encoder(encoder=vocab, bpe_merges=bpe_merges, tokenize=tokenize); return encoder_obj  # Encoder.__init__ builds byte_encoder/byte_decoder itself and accepts no such kwargs
|
70 |
|
71 |
+
def codegen_tokenize(text, tokenizer): return tokenizer.encode(text)
|
72 |
+
def codegen_decode(tokens, tokenizer): return tokenizer.decode(tokens)
|
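A round-trip sketch for the rebuilt GPT-2 BPE encoder; it assumes ENCODER_URL and VOCAB_URL are defined in constants.py and that nltk's 'punkt' data is installed, since decode() calls nltk.sent_tokenize.
# Usage sketch (illustrative): encode/decode round trip with the pure-Python tokenizer.
enc = get_encoder_gpt2()            # downloads encoder.json / vocab.bpe on first use
ids = enc.encode("Hello, world!")   # list of BPE token ids
print(ids)
print(enc.decode(ids))              # whitespace/punctuation re-joined by decode()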
translation_api.py
CHANGED
@@ -2,23 +2,14 @@ from flask import jsonify, send_file, request
|
|
2 |
from main import *
|
3 |
|
4 |
def perform_translation(text, target_language_code='es_XX', source_language_code='en_XX'):
|
5 |
-
if translation_model is None:
|
6 |
-
return {"error": "Translation model not initialized."}
|
7 |
-
|
8 |
encoded_text = translation_model.tokenizer(text, return_tensors="pt", padding=True).to(device)
|
9 |
generated_tokens = translation_model.generate(input_ids=encoded_text['input_ids'], attention_mask=encoded_text['attention_mask'], forced_bos_token_id=translation_model.config.lang_code_to_id[target_language_code])
|
10 |
-
translation = translation_model.tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
|
11 |
-
|
12 |
-
return {"translated_text": translation}
|
13 |
|
14 |
-
def translation_api():
|
15 |
-
data = request.get_json()
|
16 |
-
text
|
17 |
-
target_lang = data.get('target_lang', 'es')
|
18 |
-
source_lang = data.get('source_lang', 'en')
|
19 |
-
if not text:
|
20 |
-
return jsonify({"error": "Text is required"}), 400
|
21 |
output = perform_translation(text, target_language_code=f'{target_lang}_XX', source_language_code=f'{source_lang}_XX')
|
22 |
-
if "error" in output:
|
23 |
-
|
24 |
-
return jsonify(output)
|
|
|
2 |
from main import *
|
3 |
|
4 |
def perform_translation(text, target_language_code='es_XX', source_language_code='en_XX'):
|
5 |
+
if translation_model is None: return {"error": "Translation model not initialized."}
|
|
|
|
|
6 |
encoded_text = translation_model.tokenizer(text, return_tensors="pt", padding=True).to(device)
|
7 |
generated_tokens = translation_model.generate(input_ids=encoded_text['input_ids'], attention_mask=encoded_text['attention_mask'], forced_bos_token_id=translation_model.config.lang_code_to_id[target_language_code])
|
8 |
+
translation = translation_model.tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]; return {"translated_text": translation}
|
|
|
|
|
9 |
|
10 |
+
def translation_api(text):
|
11 |
+
data = request.get_json(silent=True) or {}; text = text or data.get('text'); target_lang = data.get('target_lang', 'es'); source_lang = data.get('source_lang', 'en')
|
12 |
+
if not text: return jsonify({"error": "Text is required"}), 400
|
13 |
output = perform_translation(text, target_language_code=f'{target_lang}_XX', source_language_code=f'{source_lang}_XX')
|
14 |
+
if "error" in output: return jsonify({"error": output["error"]}), 500
|
15 |
+
return output
|
|
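perform_translation can also be called directly, outside a Flask request; a sketch assuming translation_model is the mBART-style model (exposing a .tokenizer attribute) loaded in main.py.
# Usage sketch (illustrative): direct call without a request context.
out = perform_translation("The weather is nice today.", target_language_code="es_XX", source_language_code="en_XX")
print(out.get("translated_text") or out.get("error"))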
tts_api.py
CHANGED
@@ -1,26 +1,15 @@
|
|
1 |
-
import os
|
2 |
from flask import jsonify, send_file, request
|
3 |
from main import *
|
4 |
-
import torch
|
5 |
-
import torchaudio
|
6 |
-
import uuid
|
7 |
|
8 |
-
def text_to_speech_func(text):
|
9 |
-
if tts_model is None:
|
10 |
-
return {"error": "TTS model not initialized."}
|
11 |
input_tokens = tts_model.tokenizer(text, return_tensors="pt", padding=True).to(device)
|
12 |
-
with torch.no_grad():
|
13 |
-
|
14 |
-
temp_audio_path = f"temp_audio_{uuid.uuid4()}.wav"
|
15 |
-
torchaudio.save(temp_audio_path, audio_output.cpu(), 16000)
|
16 |
-
return temp_audio_path
|
17 |
|
18 |
-
def tts_api():
|
19 |
-
data = request.get_json()
|
20 |
-
text = data.get('text')
|
21 |
-
if not text:
|
22 |
-
return jsonify({"error": "Text is required"}), 400
|
23 |
output_file = text_to_speech_func(text)
|
24 |
-
if "error" in
|
25 |
-
|
26 |
-
|
|
|
|
|
1 |
from flask import jsonify, send_file, request
|
2 |
from main import *
|
3 |
+
import torch, torchaudio, uuid, io, base64, os
|
|
|
|
|
4 |
|
5 |
+
def text_to_speech_func(text, output_path="output_audio.wav"):
|
6 |
+
if tts_model is None: return {"error": "TTS model not initialized."}
|
|
|
7 |
input_tokens = tts_model.tokenizer(text, return_tensors="pt", padding=True).to(device)
|
8 |
+
with torch.no_grad(): audio_output = tts_model(input_tokens['input_ids'])
|
9 |
+
torchaudio.save(output_path, audio_output.cpu(), 16000); return output_path
|
|
|
|
|
|
|
10 |
|
11 |
+
def tts_api(text):
|
12 |
output_file = text_to_speech_func(text)
|
13 |
+
if isinstance(output_file, dict) and "error" in output_file: return {"error": output_file["error"]}
|
14 |
+
with open(output_file, 'rb') as f: audio_content = f.read()
|
15 |
+
audio_base64 = base64.b64encode(audio_content).decode('utf-8'); os.remove(output_file); return {"audio_base64": audio_base64, "mimetype": "audio/wav"}
|
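As with the video endpoint, the audio now comes back base64-encoded; a caller-side sketch (file name is a placeholder):
# Usage sketch (illustrative): decoding the tts_api payload to a .wav file.
import base64
result = tts_api("Hola, ¿cómo estás?")
if "error" not in result:
    with open("speech.wav", "wb") as f:
        f.write(base64.b64decode(result["audio_base64"]))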
xtts_api.py
ADDED
@@ -0,0 +1,21 @@
|
1 |
+
from flask import jsonify, send_file, request
|
2 |
+
from main import *
|
3 |
+
import torch, torchaudio, io, base64, uuid, os
|
4 |
+
|
5 |
+
def xtts_clone_func(text, audio_sample_path, output_path="output_xtts_audio.wav"):
|
6 |
+
if xtts_model is None: return {"error": "XTTS model not initialized."}
|
7 |
+
language = "en"; speaker_id = 0
|
8 |
+
try:
|
9 |
+
with torch.no_grad(): wav = xtts_model.inference(text=text, language_id=language, speaker_id=speaker_id, voice_sample=audio_sample_path, temperature=0.7, length_penalty=1.0)
|
10 |
+
except Exception as e: return {"error": f"XTTS inference failed: {e}"}
|
11 |
+
torchaudio.save(output_path, wav, 24000); return output_path
|
12 |
+
|
13 |
+
def xtts_api(inputs):
|
14 |
+
text = inputs[0]; audio_sample_filepath = inputs[1]
|
15 |
+
temp_audio_path = f"temp_audio_{uuid.uuid4()}.wav"; os.rename(audio_sample_filepath, temp_audio_path)
|
16 |
+
output = xtts_clone_func(text, temp_audio_path); os.remove(temp_audio_path)
|
17 |
+
if isinstance(output, dict) and "error" in output: return {"error": output["error"]}
|
18 |
+
output_file = output
|
19 |
+
with open(output_file, 'rb') as f: audio_content = f.read()
|
20 |
+
audio_base64 = base64.b64encode(audio_content).decode('utf-8'); os.remove(output_file); return {"audio_base64": audio_base64, "mimetype": "audio/wav"}
|
|
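xtts_api expects a two-element list and renames (then deletes) the reference sample it is given, so callers should hand it a disposable copy; a sketch with placeholder file names:
# Usage sketch (illustrative): voice cloning with a throwaway copy of the reference audio.
import base64, shutil
shutil.copy("reference_voice.wav", "reference_copy.wav")   # xtts_api renames and removes its input file
result = xtts_api(["This is a cloned-voice test.", "reference_copy.wav"])
if "error" not in result:
    with open("cloned.wav", "wb") as f:
        f.write(base64.b64decode(result["audio_base64"]))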
xxx.py
CHANGED
@@ -1,142 +1,71 @@
|
|
1 |
-
import json
|
2 |
-
import re
|
3 |
-
import unicodedata
|
4 |
from functools import lru_cache
|
5 |
-
import wget
|
6 |
-
import os
|
7 |
from constants import GPT2_FOLDER, ENCODER_FILE, VOCAB_FILE, END_OF_TEXT_TOKEN
|
8 |
-
import nltk
|
9 |
|
10 |
@lru_cache()
|
11 |
def bytes_to_unicode():
|
12 |
bs = list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1))
|
13 |
-
cs = bs[:]
|
14 |
-
n
|
15 |
-
for
|
16 |
-
if b not in bs:
|
17 |
-
bs.append(b)
|
18 |
-
cs.append(2**8 + n)
|
19 |
-
n += 1
|
20 |
-
cs = [chr(n) for n in cs]
|
21 |
-
return dict(zip(bs, cs))
|
22 |
|
23 |
def get_pairs(word):
|
24 |
-
pairs = set()
|
25 |
-
|
26 |
-
for char in word[1:]:
|
27 |
-
pairs.add((prev_char, char))
|
28 |
-
prev_char = char
|
29 |
-
return pairs
|
30 |
|
31 |
class Encoder:
|
32 |
def __init__(self, encoder, bpe_merges, errors='replace', tokenize=None):
|
33 |
-
self.encoder = encoder
|
34 |
-
self.
|
35 |
-
self.errors = errors
|
36 |
-
self.byte_encoder = bytes_to_unicode()
|
37 |
-
self.byte_decoder = {v:k for k, v in self.byte_encoder.items()}
|
38 |
self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
|
39 |
-
self.cache = {}
|
40 |
-
if tokenize is None:
|
41 |
-
|
42 |
-
self.tokenize = lambda text: re.findall(self.pat, text)
|
43 |
-
else:
|
44 |
-
self.tokenize = tokenize
|
45 |
|
46 |
def bpe(self, token):
|
47 |
-
if token in self.cache:
|
48 |
-
|
49 |
-
|
50 |
-
pairs = get_pairs(word)
|
51 |
-
if not pairs:
|
52 |
-
return token
|
53 |
while True:
|
54 |
bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
|
55 |
-
if bigram not in self.bpe_ranks:
|
56 |
-
|
57 |
-
first, second = bigram
|
58 |
-
new_word = []
|
59 |
-
i = 0
|
60 |
while i < len(word):
|
61 |
-
try:
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
new_word.append(first+second)
|
70 |
-
i += 2
|
71 |
-
else:
|
72 |
-
new_word.append(word[i])
|
73 |
-
i += 1
|
74 |
-
new_word = tuple(new_word)
|
75 |
-
word = new_word
|
76 |
-
if len(word) == 1:
|
77 |
-
break
|
78 |
-
else:
|
79 |
-
pairs = get_pairs(word)
|
80 |
-
word = ' '.join(word)
|
81 |
-
self.cache[token] = word
|
82 |
-
return word
|
83 |
|
84 |
def encode(self, text):
|
85 |
-
bpe_tokens = []
|
86 |
-
normalized_text =
|
87 |
-
normalized_text = ''.join(c for c in normalized_text if c.isascii() and c != '\t')
|
88 |
-
normalized_text = ''.join(c for c in normalized_text if not unicodedata.category(c).startswith('C'))
|
89 |
-
for token in self.tokenize(normalized_text):
|
90 |
-
token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8', errors='ignore'))
|
91 |
-
bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
|
92 |
return bpe_tokens
|
93 |
|
94 |
def decode(self, tokens):
|
95 |
-
text = ''.join([self.decoder[token] for token in tokens])
|
96 |
-
text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors='replace')
|
97 |
decoded_text = text.replace(" .", ".").replace(" ,", ",").replace(" '", "'").replace(" ?", "?").replace(" !", "!").replace(" :", ":").replace('\n', '<br>')
|
98 |
-
sentences = nltk.sent_tokenize(decoded_text)
|
99 |
-
return ' '.join(sentences).replace("<br>", "<br>\n")
|
100 |
|
101 |
def get_encoder_gpt2():
|
102 |
-
encoder_path = os.path.join(GPT2_FOLDER, ENCODER_FILE)
|
103 |
-
|
104 |
-
if not os.path.exists(
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
-
with open(encoder_path, 'r') as f:
|
112 |
-
encoder = json.load(f)
|
113 |
-
with open(vocab_path, 'r', encoding="utf-8") as f:
|
114 |
-
bpe_data = f.read()
|
115 |
-
bpe_merges = [tuple(merge_str.split()) for merge_str in bpe_data.split('\n')[1:-1]]
|
116 |
-
encoder_obj = Encoder(encoder=encoder, bpe_merges=bpe_merges)
|
117 |
-
encoder_obj.encoder[END_OF_TEXT_TOKEN] = len(encoder_obj.encoder)
|
118 |
-
encoder_obj.decoder[len(encoder_obj.decoder)] = END_OF_TEXT_TOKEN
|
119 |
-
return encoder_obj
|
120 |
|
121 |
def get_codegen_tokenizer_pure(vocab_file, merges_file):
|
122 |
-
vocab = json.load(open(vocab_file))
|
123 |
-
|
124 |
-
|
125 |
-
byte_encoder =
|
126 |
-
byte_decoder = {v: k for k, v in byte_encoder.items()}
|
127 |
-
tokenizer_regex = re.compile(r'''<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+''')
|
128 |
-
tokenize = lambda text: re.findall(tokenizer_regex, text)
|
129 |
-
encoder_obj = Encoder(
|
130 |
-
encoder=vocab,
|
131 |
-
bpe_merges=bpe_merges,
|
132 |
-
byte_encoder=byte_encoder,
|
133 |
-
byte_decoder=byte_decoder,
|
134 |
-
tokenize=tokenize
|
135 |
-
)
|
136 |
-
return encoder_obj
|
137 |
-
|
138 |
-
def codegen_tokenize(text, tokenizer):
|
139 |
-
return tokenizer.encode(text)
|
140 |
|
141 |
-
def
|
142 |
-
|
|
|
1 |
+
import json, re, unicodedata, regex  # third-party 'regex' is needed for the \p{...} classes used below
|
|
|
|
|
2 |
from functools import lru_cache
|
3 |
+
import wget, os, nltk
from constants import ENCODER_URL, VOCAB_URL
|
|
|
4 |
from constants import GPT2_FOLDER, ENCODER_FILE, VOCAB_FILE, END_OF_TEXT_TOKEN
|
|
|
5 |
|
6 |
@lru_cache()
|
7 |
def bytes_to_unicode():
|
8 |
bs = list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1))
|
9 |
+
cs = bs[:]; n = 0
|
10 |
+
for b in range(2**8):
    if b not in bs: bs.append(b); cs.append(2**8 + n); n += 1
|
11 |
+
cs = [chr(n) for n in cs]; return dict(zip(bs, cs))
|
12 |
|
13 |
def get_pairs(word):
|
14 |
+
pairs = set(); prev_char = word[0]
|
15 |
+
for char in word[1:]: pairs.add((prev_char, char)); prev_char = char
return pairs
|
16 |
|
17 |
class Encoder:
|
18 |
def __init__(self, encoder, bpe_merges, errors='replace', tokenize=None):
|
19 |
+
self.encoder = encoder; self.decoder = {v:k for k,v in self.encoder.items()}; self.errors = errors
|
20 |
+
self.byte_encoder = bytes_to_unicode(); self.byte_decoder = {v:k for k, v in self.byte_encoder.items()}
|
|
|
|
|
|
|
21 |
self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
|
22 |
+
self.cache = {};
|
23 |
+
if tokenize is None: self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\w+| ?[^\s\w]+|\s+(?!\S)|\s+""", re.UNICODE); self.tokenize = lambda text: re.findall(self.pat, text)
|
24 |
+
else: self.tokenize = tokenize
|
|
|
|
|
|
|
25 |
|
26 |
def bpe(self, token):
|
27 |
+
if token in self.cache: return self.cache[token]
|
28 |
+
word = tuple(token); pairs = get_pairs(word)
|
29 |
+
if not pairs: return token
|
|
|
|
|
|
|
30 |
while True:
|
31 |
bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
|
32 |
+
if bigram not in self.bpe_ranks: break
|
33 |
+
first, second = bigram; new_word = []; i = 0
|
|
|
|
|
|
|
34 |
while i < len(word):
|
35 |
+
try: j = word.index(first, i); new_word.extend(word[i:j]); i = j
|
36 |
+
except ValueError: new_word.extend(word[i:]); break
|
37 |
+
if word[i] == first and i < len(word)-1 and word[i+1] == second: new_word.append(first+second); i += 2
|
38 |
+
else: new_word.append(word[i]); i += 1
|
39 |
+
new_word = tuple(new_word); word = new_word
|
40 |
+
if len(word) == 1: break
|
41 |
+
else: pairs = get_pairs(word)
|
42 |
+
word = ' '.join(word); self.cache[token] = word; return word
|
43 |
|
44 |
def encode(self, text):
|
45 |
+
bpe_tokens = []; normalized_text = unicodedata.normalize('NFKC', text); normalized_text = ''.join(c for c in normalized_text if c.isascii() and c != '\t'); normalized_text = ''.join(c for c in normalized_text if not unicodedata.category(c).startswith('C'))
|
46 |
+
for token in self.tokenize(normalized_text): token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8', errors='ignore')); bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
|
47 |
return bpe_tokens
|
48 |
|
49 |
def decode(self, tokens):
|
50 |
+
text = ''.join([self.decoder[token] for token in tokens]); text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors='replace')
|
|
|
51 |
decoded_text = text.replace(" .", ".").replace(" ,", ",").replace(" '", "'").replace(" ?", "?").replace(" !", "!").replace(" :", ":").replace('\n', '<br>')
|
52 |
+
sentences = nltk.sent_tokenize(decoded_text); return ' '.join(sentences).replace("<br>", "<br>\n")
|
|
|
53 |
|
54 |
def get_encoder_gpt2():
|
55 |
+
encoder_path = os.path.join(GPT2_FOLDER, ENCODER_FILE); vocab_path = os.path.join(GPT2_FOLDER, VOCAB_FILE)
|
56 |
+
if not os.path.exists(GPT2_FOLDER): os.makedirs(GPT2_FOLDER)
|
57 |
+
if not os.path.exists(encoder_path): wget.download(ENCODER_URL, out=encoder_path)
|
58 |
+
if not os.path.exists(vocab_path): wget.download(VOCAB_URL, out=vocab_path)
|
59 |
+
with open(encoder_path, 'r') as f: encoder = json.load(f)
|
60 |
+
with open(vocab_path, 'r', encoding="utf-8") as f: bpe_data = f.read()
|
61 |
+
bpe_merges = [tuple(merge_str.split()) for merge_str in bpe_data.split('\n')[1:-1]]; encoder_obj = Encoder(encoder=encoder, bpe_merges=bpe_merges)
|
62 |
+
encoder_obj.encoder[END_OF_TEXT_TOKEN] = len(encoder_obj.encoder); encoder_obj.decoder[len(encoder_obj.decoder)] = END_OF_TEXT_TOKEN; return encoder_obj
|
63 |
|
64 |
def get_codegen_tokenizer_pure(vocab_file, merges_file):
|
65 |
+
vocab = json.load(open(vocab_file)); merges = open(merges_file, 'r', encoding="utf-8").read().split('\n')[1:-1]; bpe_merges = [tuple(m.split()) for m in merges]
|
66 |
+
byte_encoder = bytes_to_unicode(); byte_decoder = {v: k for k, v in byte_encoder.items()}
|
67 |
+
tokenizer_regex = regex.compile(r'''<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+'''); tokenize = lambda text: tokenizer_regex.findall(text)
|
68 |
+
encoder_obj = Encoder(encoder=vocab, bpe_merges=bpe_merges, tokenize=tokenize); return encoder_obj  # Encoder.__init__ builds byte_encoder/byte_decoder itself and accepts no such kwargs
|
69 |
|
70 |
+
def codegen_tokenize(text, tokenizer): return tokenizer.encode(text)
|
71 |
+
def codegen_decode(tokens, tokenizer): return tokenizer.decode(tokens)
|
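xxx.py mirrors tokenxxx.py; to close, a sketch of the CodeGen tokenizer helpers, with the vocab/merges paths as placeholders.
# Usage sketch (illustrative): building and using the pure-Python CodeGen tokenizer.
tok = get_codegen_tokenizer_pure("codegen/vocab.json", "codegen/merges.txt")
ids = codegen_tokenize("def add(a, b):\n    return a + b", tok)
print(codegen_decode(ids, tok))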