Hhhh / main.py
Hjgugugjhuhjggg's picture
V2
0ff6756 verified
raw
history blame
4.93 kB
import threading
import queue
import time
import os
import nltk
import re
import json
from flask import Flask
from flask_cors import CORS
from api import *
from extensions import *
from constants import *
from configs import *
from tokenxxx import *
from models import *
from model_loader import *
from utils import *
from background_tasks import generate_and_queue_text, background_training, background_reasoning_queue
from text_generation import *
from sadtalker_utils import *
import torch
# ---------------------------------------------------------------------------
# Shared module-level state.
# ---------------------------------------------------------------------------

# Compute device: CUDA when torch reports it available, otherwise CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Handles assigned by load_models(); None until it has run.
model_gpt2 = None
enc = None
translation_model = None
codegen_model = None
codegen_tokenizer = None
codegen_vocabulary = None
codegen_index_to_word = None
codegen_word_to_index = None
summarization_model = None
imagegen_model = None
image_to_3d_model = None
text_to_video_model = None
sentiment_model = None
stt_model = None
tts_model = None
musicgen_model = None

# Placeholders that nothing in the visible part of this file assigns.
# NOTE(review): presumably consumed or set by other modules -- confirm.
state_dict = None
config = None
news_clf = None
tfidf_vectorizer = None
sp = None
sadtalker_instance = None

# Set in the ``__main__`` block before the background threads start.
categories = None

# Work queues; text_queue and reasoning_queue are handed to the
# background_tasks module in the ``__main__`` block.
text_queue = queue.Queue()
feedback_queue = queue.Queue()
reasoning_queue = queue.Queue()

# Thread objects created and started in the ``__main__`` block.
background_threads = []

# Conversation bookkeeping (not touched by the visible code in this file).
seen_responses = set()
dialogue_history = []

# Vocabulary maintained incrementally by tokenize_text().
vocabulary = set()
word_to_index = {}
index_to_word = []

# Summarization vocabulary containers (left empty by the visible code).
summarization_vocabulary = set()
summarization_word_to_index = {}
summarization_index_to_word = []

# Initial stream mode.
stream_type = "text"
def load_models():
    """Initialise every model handle and store it in a module-level global.

    Each ``initialize_*`` helper is called once, in the order written below,
    with a folder constant and a file/URL constant. The helpers and the
    constants are not defined in this file; they arrive through the star
    imports at the top of the module. No exception is caught here, so a
    failure in any helper propagates to the caller and the remaining
    handles stay at their previous value.

    Returns:
        None. Results are published through the globals declared below.
    """
    # Only names that are actually rebound here need a ``global``
    # declaration; names that were declared but never assigned were removed.
    global model_gpt2, enc
    global translation_model
    global codegen_model, codegen_tokenizer
    global codegen_vocabulary, codegen_index_to_word, codegen_word_to_index
    global summarization_model
    global imagegen_model, image_to_3d_model, text_to_video_model
    global sentiment_model, stt_model, tts_model, musicgen_model

    # NOTE(review): ``GPT2CONFHG`` is used exactly as imported; the spelling
    # looks unusual -- verify it against the constants module.
    gpt2_files = {
        MODEL_FILE: MODEL_URL,
        ENCODER_FILE: ENCODER_URL,
        VOCAB_FILE: VOCAB_URL,
        CONFIG_FILE: GPT2CONFHG,
    }
    model_gpt2, enc = initialize_gpt2_model(GPT2_FOLDER, gpt2_files)

    translation_model = initialize_translation_model(
        TRANSLATION_FOLDER, TRANSLATION_MODEL_FILES_URLS
    )

    (
        codegen_model,
        codegen_tokenizer,
        codegen_vocabulary,
        codegen_index_to_word,
        codegen_word_to_index,
    ) = initialize_codegen_model(CODEGEN_FOLDER, CODEGEN_FILES_URLS)

    # NOTE(review): the helper returns four values but only the first is
    # kept, so the module-level summarization_* containers are not filled
    # here -- confirm this is intended.
    summarization_model, _, _, _ = initialize_summarization_model(
        SUMMARIZATION_FOLDER, SUMMARIZATION_FILES_URLS
    )

    imagegen_model = initialize_imagegen_model(IMAGEGEN_FOLDER, IMAGEGEN_FILES_URLS)
    image_to_3d_model = initialize_image_to_3d_model(
        IMAGE_TO_3D_FOLDER, IMAGE_TO_3D_FILES_URLS
    )
    text_to_video_model = initialize_text_to_video_model(
        TEXT_TO_VIDEO_FOLDER, TEXT_TO_VIDEO_FILES_URLS
    )
    sentiment_model = initialize_sentiment_model(SENTIMENT_FOLDER, SENTIMENT_FILES_URLS)
    stt_model = initialize_stt_model(STT_FOLDER, STT_FILES_URLS)
    tts_model = initialize_tts_model(TTS_FOLDER, TTS_FILES_URLS)
    musicgen_model = initialize_musicgen_model(MUSICGEN_FOLDER, MUSICGEN_FILES_URLS)
class SimpleClassifier(torch.nn.Module):
    """Embedding lookup, mean over dimension 1, then a linear layer.

    Args:
        vocab_size: Number of rows in the embedding table.
        num_classes: Output size of the final linear layer.
        embedding_dim: Width of each embedding vector. Defaults to 128,
            the value that used to be hard-coded, so existing callers
            are unaffected.
    """

    def __init__(self, vocab_size, num_classes, embedding_dim=128):
        super().__init__()
        # Attribute names and creation order are kept so parameter
        # initialisation and state-dict keys stay the same.
        self.embedding = torch.nn.Embedding(vocab_size, embedding_dim)
        self.linear = torch.nn.Linear(embedding_dim, num_classes)

    def forward(self, x):
        """Return ``linear(mean(embedding(x), dim=1))``.

        Args:
            x: Integer index tensor accepted by ``torch.nn.Embedding``.
                Assumes dimension 1 is the token axis, e.g. shape
                (batch, seq) -- TODO confirm with callers.

        Returns:
            Tensor whose last dimension has size ``num_classes``.
        """
        embedded = self.embedding(x)
        # Average the embedding vectors along dimension 1.
        pooled = embedded.mean(dim=1)
        return self.linear(pooled)
def tokenize_text(text):
    """Lower-case *text*, split it on whitespace and register new tokens.

    Every token not yet in the module-level ``vocabulary`` is added to it,
    receives the next free position in ``index_to_word`` and is recorded
    in ``word_to_index``.

    Args:
        text: String to tokenize.

    Returns:
        The list of lower-cased tokens in their original order,
        duplicates included.
    """
    # The shared containers are mutated in place, never rebound, so no
    # ``global`` declaration is required.
    words = text.lower().split()
    for word in words:
        if word in vocabulary:
            continue
        vocabulary.add(word)
        # The new token's index is the current length of index_to_word.
        word_to_index[word] = len(index_to_word)
        index_to_word.append(word)
    return words
def text_to_vector(text):
    """Return a token-count vector for *text* over the shared vocabulary.

    The text is first passed through ``tokenize_text``, which may add new
    tokens to the module-level vocabulary, so the returned vector's length
    is the vocabulary size *after* this call.

    Args:
        text: String to vectorize.

    Returns:
        A 1-D ``torch`` tensor of length ``len(vocabulary)`` in which the
        entry at ``word_to_index[token]`` holds the number of times the
        token occurs in *text*.
    """
    words = tokenize_text(text)
    counts = torch.zeros(len(vocabulary))
    for word in words:
        position = word_to_index.get(word)
        if position is not None:
            counts[position] += 1
    return counts
if __name__ == "__main__":
    # Download the NLTK 'punkt' resource, then load all model handles.
    nltk.download('punkt')
    load_models()

    categories = ['Category1', 'Category2', 'Category3', 'Category4', 'Category5']

    # Hand the shared category list and queues to the background_tasks
    # module by setting attributes on it before any worker starts.
    import background_tasks
    background_tasks.categories = categories
    background_tasks.text_queue = text_queue
    background_tasks.reasoning_queue = reasoning_queue

    # One daemon thread per (target, args) entry, in this order: two text
    # generators ('en' and 'es'), the training loop and the reasoning
    # queue worker. Daemon threads do not keep the process alive on exit.
    background_threads.extend(
        threading.Thread(target=worker, args=worker_args, daemon=True)
        for worker, worker_args in (
            (generate_and_queue_text, ('en',)),
            (generate_and_queue_text, ('es',)),
            (background_training, ()),
            (background_reasoning_queue, ()),
        )
    )
    for thread in background_threads:
        thread.start()

    # ``app`` is not defined in this file; it is expected to come from one
    # of the star imports. Serve on all interfaces, port 7860.
    app.run(host='0.0.0.0', port=7860)