import threading
import queue
import time
import os
import nltk
import re
import json
from flask import Flask
from flask_cors import CORS

# Project modules. The wildcard imports are kept in their original order so name
# resolution is unchanged; the Flask `app` used at the bottom of the file comes
# from one of these modules, along with the constants and initialize_* loaders.
from api import *
from extensions import *
from constants import *
from configs import *
from tokenxxx import *
from models import *
from model_loader import *
from utils import *
from background_tasks import generate_and_queue_text, background_training, background_reasoning_queue
from text_generation import *
from sadtalker_utils import *
import torch

# Module-level state shared across the app: loaded models, vocabulary maps,
# and the queues consumed by the background worker threads.
state_dict = None
enc = None
config = None
model_gpt2 = None
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
news_clf = None
tfidf_vectorizer = None
text_queue = queue.Queue()
categories = None
background_threads = []
feedback_queue = queue.Queue()
reasoning_queue = queue.Queue()
seen_responses = set()
dialogue_history = []
vocabulary = set()
word_to_index = {}
index_to_word = []
translation_model = None
sp = None
codegen_model = None
codegen_tokenizer = None
codegen_vocabulary = None
codegen_index_to_word = None
codegen_word_to_index = None
summarization_model = None
summarization_vocabulary = set()
summarization_word_to_index = {}
summarization_index_to_word = []
sadtalker_instance = None
imagegen_model = None
image_to_3d_model = None
text_to_video_model = None
stream_type = "text"
sentiment_model = None
stt_model = None
tts_model = None
musicgen_model = None


def load_models():
    """Initialize every model once and store the results in module-level globals."""
    global model_gpt2, enc, translation_model, codegen_model, codegen_tokenizer, \
        codegen_vocabulary, codegen_index_to_word, codegen_word_to_index, \
        summarization_model, imagegen_model, image_to_3d_model, text_to_video_model, \
        sadtalker_instance, sentiment_model, stt_model, tts_model, musicgen_model, \
        checkpoint_path, gfpgan_model_file, restoreformer_model_file, codeformer_model_file, \
        realesrgan_model_file, kp_file, aud_file, wav_file, gen_file, mapx_file, den_file
    model_gpt2, enc = initialize_gpt2_model(GPT2_FOLDER, {MODEL_FILE: MODEL_URL, ENCODER_FILE: ENCODER_URL, VOCAB_FILE: VOCAB_URL, CONFIG_FILE: GPT2CONFHG})
    translation_model = initialize_translation_model(TRANSLATION_FOLDER, TRANSLATION_MODEL_FILES_URLS)
    codegen_model, codegen_tokenizer, codegen_vocabulary, codegen_index_to_word, codegen_word_to_index = initialize_codegen_model(CODEGEN_FOLDER, CODEGEN_FILES_URLS)
    summarization_model, _, _, _ = initialize_summarization_model(SUMMARIZATION_FOLDER, SUMMARIZATION_FILES_URLS)
    imagegen_model = initialize_imagegen_model(IMAGEGEN_FOLDER, IMAGEGEN_FILES_URLS)
    image_to_3d_model = initialize_image_to_3d_model(IMAGE_TO_3D_FOLDER, IMAGE_TO_3D_FILES_URLS)
    text_to_video_model = initialize_text_to_video_model(TEXT_TO_VIDEO_FOLDER, TEXT_TO_VIDEO_FILES_URLS)
    sentiment_model = initialize_sentiment_model(SENTIMENT_FOLDER, SENTIMENT_FILES_URLS)
    stt_model = initialize_stt_model(STT_FOLDER, STT_FILES_URLS)
    tts_model = initialize_tts_model(TTS_FOLDER, TTS_FILES_URLS)
    musicgen_model = initialize_musicgen_model(MUSICGEN_FOLDER, MUSICGEN_FILES_URLS)


class SimpleClassifier(torch.nn.Module):
    """Bag-of-embeddings classifier: embed tokens, mean-pool, then a linear layer."""

    def __init__(self, vocab_size, num_classes):
        super(SimpleClassifier, self).__init__()
        self.embedding = torch.nn.Embedding(vocab_size, 128)
        self.linear = torch.nn.Linear(128, num_classes)

    def forward(self, x):
        embedded = self.embedding(x)
        pooled = torch.mean(embedded, dim=1)
        return self.linear(pooled)


def tokenize_text(text):
    """Lowercase whitespace tokenization; unseen tokens are added to the global vocabulary."""
    global vocabulary, word_to_index, index_to_word
    tokens = text.lower().split()
    for token in tokens:
        if token not in vocabulary:
            vocabulary.add(token)
            word_to_index[token] = len(index_to_word)
            index_to_word.append(token)
    return tokens


def text_to_vector(text):
    """Convert text into a bag-of-words count vector over the current vocabulary."""
    global vocabulary, word_to_index
    tokens = tokenize_text(text)
    vector = torch.zeros(len(vocabulary))
    for token in tokens:
        if token in word_to_index:
            vector[word_to_index[token]] += 1
    return vector


if __name__ == "__main__":
    nltk.download('punkt')
    load_models()
    categories = ['Category1', 'Category2', 'Category3', 'Category4', 'Category5']
    # Share mutable state with the background_tasks module before starting the workers.
    import background_tasks
    background_tasks.categories = categories
    background_tasks.text_queue = text_queue
    background_tasks.reasoning_queue = reasoning_queue
    # Daemon workers: text generation in English and Spanish, training, and the reasoning queue.
    background_threads.append(threading.Thread(target=generate_and_queue_text, args=('en',), daemon=True))
    background_threads.append(threading.Thread(target=generate_and_queue_text, args=('es',), daemon=True))
    background_threads.append(threading.Thread(target=background_training, daemon=True))
    background_threads.append(threading.Thread(target=background_reasoning_queue, daemon=True))
    for thread in background_threads:
        thread.start()
    app.run(host='0.0.0.0', port=7860)