# Wildcard imports expose the sibling API modules' globals to this app
# (presumably for side effects such as model loading); the endpoints defined
# below are otherwise self-contained placeholders.
from main import *
from tts_api import *
from stt_api import *
from sentiment_api import *
from imagegen_api import *
from musicgen_api import *
from translation_api import *
from codegen_api import *
from text_to_video_api import *
from summarization_api import *
from image_to_3d_api import *
from flask import Flask, request, jsonify, Response, send_file, stream_with_context
from flask_cors import CORS
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
import numpy as np
from PIL import Image
import io
import tempfile
import queue
import json
import base64
app = Flask(__name__)
CORS(app)
html_code = """<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>AI Text Generation</title>
<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/animate.css/4.1.1/animate.min.css"/>
<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.0.0/css/all.min.css" integrity="sha512-9usAa10IRO0HhonpyAIVpjrylPvoDwiPUiKdWk5t3PyolY1cOd4DSE0Ga+ri4AuTroPR5aQvXU9xC6qOPnzFeg==" crossorigin="anonymous" referrerpolicy="no-referrer" />
<script src="https://cdn.jsdelivr.net/npm/marked/marked.min.js"></script>
<style>
body {
font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
background: #f0f0f0;
color: #333;
margin: 0;
padding: 0;
display: flex;
flex-direction: column;
align-items: center;
min-height: 100vh;
}
.container {
width: 95%;
max-width: 900px;
padding: 20px;
background-color: #fff;
box-shadow: 0 0 10px rgba(0, 0, 0, 0.1);
border-radius: 8px;
margin-top: 20px;
margin-bottom: 20px;
display: flex;
flex-direction: column;
}
.header {
text-align: center;
margin-bottom: 20px;
}
.header h1 {
font-size: 2em;
color: #333;
}
.form-group {
margin-bottom: 15px;
}
.form-group textarea {
width: 100%;
padding: 10px;
border: 1px solid #ccc;
border-radius: 5px;
font-size: 16px;
box-sizing: border-box;
resize: vertical;
}
button {
padding: 10px 15px;
border: none;
border-radius: 5px;
background-color: #007bff;
color: white;
font-size: 18px;
cursor: pointer;
transition: background-color 0.3s ease;
}
button:hover {
background-color: #0056b3;
}
#output {
margin-top: 20px;
padding: 15px;
border: 1px solid #ddd;
border-radius: 5px;
background-color: #f9f9f9;
white-space: pre-wrap;
word-break: break-word;
overflow-y: auto;
max-height: 100vh;
}
#output strong {
font-weight: bold;
}
.animated-text {
position: fixed;
top: 20px;
left: 20px;
font-size: 1.5em;
color: rgba(0, 0, 0, 0.1);
pointer-events: none;
z-index: -1;
}
@media (max-width: 768px) {
.container {
width: 98%;
margin-top: 10px;
margin-bottom: 10px;
padding: 15px;
}
.header h1 {
font-size: 1.8em;
}
.form-group textarea, .form-group input[type="text"] {
font-size: 14px;
padding: 8px;
}
button {
font-size: 16px;
padding: 8px 12px;
}
#output {
font-size: 14px;
padding: 10px;
margin-top: 15px;
}
}
</style>
</head>
<body>
<div class="animated-text animate__animated animate__fadeIn animate__infinite">AI POWERED</div>
<div class="container">
<div class="header animate__animated animate__fadeInDown">
</div>
<div class="form-group animate__animated animate__fadeInLeft">
<textarea id="text" rows="5" placeholder="Enter text"></textarea>
</div>
<button onclick="generateText()" class="animate__animated animate__fadeInUp">Generate Reasoning</button>
<div id="output" class="animate__animated">
<strong>Response:</strong><br>
<span id="generatedText"></span>
</div>
</div>
<script>
let eventSource = null;
let accumulatedText = "";
let lastResponse = "";
async function generateText() {
const inputText = document.getElementById("text").value;
document.getElementById("generatedText").innerText = "";
accumulatedText = "";
if (eventSource) {
eventSource.close();
}
const temp = 0.7;
const top_k_val = 40;
const top_p_val = 0.0;
const repetition_penalty_val = 1.2;
const requestData = {
text: inputText,
temp: temp,
top_k: top_k_val,
top_p: top_p_val,
reppenalty: repetition_penalty_val
};
const params = new URLSearchParams(requestData).toString();
eventSource = new EventSource('/api/v1/generate_stream?' + params);
eventSource.onmessage = function(event) {
if (event.data === "<END_STREAM>") {
eventSource.close();
const currentResponse = accumulatedText.replace(/<\|endoftext\|>/g, "").replace(/\s+(?=[.,，。])/g, '').trim();
if (currentResponse === lastResponse.trim()) {
accumulatedText = "**Response is repetitive. Please try again or rephrase your query.**";
} else {
lastResponse = currentResponse;
}
document.getElementById("generatedText").innerHTML = marked.parse(accumulatedText);
return;
}
accumulatedText += event.data;
let partialText = accumulatedText.replace(/<\|endoftext\|>/g, "").replace(/\s+(?=[.,，。])/g, '').trim();
document.getElementById("generatedText").innerHTML = marked.parse(partialText);
};
eventSource.onerror = function(error) {
console.error("SSE error", error);
eventSource.close();
};
const outputDiv = document.getElementById("output");
outputDiv.classList.add("show");
}
// Helper for decoding base64 payloads into Blobs (currently unused on this page).
function base64ToBlob(base64Data, contentType) {
contentType = contentType || '';
const sliceSize = 1024;
const byteCharacters = atob(base64Data);
const bytesLength = byteCharacters.length;
const slicesCount = Math.ceil(bytesLength / sliceSize);
const byteArrays = new Array(slicesCount);
for (let sliceIndex = 0; sliceIndex < slicesCount; ++sliceIndex) {
const begin = sliceIndex * sliceSize;
const end = Math.min(begin + sliceSize, bytesLength);
const bytes = new Array(end - begin);
for (let offset = begin, i = 0; offset < end; ++i, ++offset) {
bytes[i] = byteCharacters[offset].charCodeAt(0);
}
byteArrays[sliceIndex] = new Uint8Array(bytes);
}
return new Blob(byteArrays, { type: contentType });
}
</script>
</body>
</html>
"""
# In-memory queue that buffers user feedback submitted via /api/v1/feedback.
feedback_queue = queue.Queue()
class TextGenerationModel(nn.Module):
    """Tiny GRU language model over the toy vocabulary below."""
    def __init__(self, vocab_size, embed_dim, hidden_dim):
        super(TextGenerationModel, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.rnn = nn.GRU(embed_dim, hidden_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, vocab_size)

    def forward(self, x, hidden=None):
        x = self.embedding(x)
        out, hidden = self.rnn(x, hidden)
        out = self.fc(out)
        return out, hidden
# Toy Spanish vocabulary; "mundo" doubles as the end-of-sequence token for streaming.
vocab = ["hola", "mundo", "este", "es", "un", "ejemplo", "de", "texto", "generado", "con", "torch"]
vocab_size = len(vocab)
embed_dim = 16
hidden_dim = 32
text_model = TextGenerationModel(vocab_size, embed_dim, hidden_dim)
text_model.eval()
def tokenize(text):
    # Map whitespace-separated tokens to vocab indices; out-of-vocabulary
    # tokens fall back to index 0 ("hola").
    tokens = text.lower().split()
    indices = [vocab.index(token) if token in vocab else 0 for token in tokens]
    return torch.tensor(indices, dtype=torch.long).unsqueeze(0)
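
# Quick sanity check of the tokenizer (illustrative only): "hola" and "mundo"
# sit at vocab indices 0 and 1, so tokenize("hola mundo") returns
# tensor([[0, 1]]) with a leading batch dimension, i.e. shape (1, 2).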
def perform_reasoning_stream(text, temperature, top_k, top_p, repetition_penalty):
    # Top-k sampling loop over the toy GRU model. Note: top_p and
    # repetition_penalty are accepted for API compatibility but not applied.
    input_tensor = tokenize(text)
    hidden = None
    max_steps = 256  # safety cap so the stream terminates even if "mundo" is never sampled
    for _ in range(max_steps):
        outputs, hidden = text_model(input_tensor, hidden)
        logits = outputs[:, -1, :] / temperature
        probs = F.softmax(logits, dim=-1)
        topk_probs, topk_indices = torch.topk(probs, min(top_k, logits.shape[-1]))
        chosen_index = topk_indices[0, torch.multinomial(topk_probs[0], 1).item()].item()
        token_str = vocab[chosen_index]
        yield token_str
        input_tensor = torch.cat([input_tensor, torch.tensor([[chosen_index]], dtype=torch.long)], dim=1)
        if token_str == "mundo":  # toy end-of-sequence token
            break
    yield "<END_STREAM>"
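
# Minimal usage sketch for the generator (output is effectively random because
# the model weights are untrained):
#   for tok in perform_reasoning_stream("hola", temperature=0.7, top_k=40,
#                                       top_p=0.0, repetition_penalty=1.2):
#       print(tok, end=" ")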
class SentimentModel(nn.Module):
    """Two-layer MLP placeholder classifier (weights are untrained)."""
    def __init__(self, input_dim, hidden_dim, output_dim):
        super(SentimentModel, self).__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

sentiment_model = SentimentModel(10, 16, 2)
sentiment_model.eval()
@app.route("/")
def index():
return html_code
@app.route("/api/v1/generate_stream", methods=["GET"])
def generate_stream():
text = request.args.get("text", "")
temp = float(request.args.get("temp", 0.7))
top_k = int(request.args.get("top_k", 40))
top_p = float(request.args.get("top_p", 0.0))
reppenalty = float(request.args.get("reppenalty", 1.2))
@stream_with_context
def event_stream():
try:
for token in perform_reasoning_stream(text, temperature=temp, top_k=top_k, top_p=top_p, repetition_penalty=reppenalty):
if token == "<END_STREAM>":
yield "data: <END_STREAM>\n\n"
break
yield "data: " + token + "\n\n"
except Exception as e:
yield "data: <ERROR>\n\n"
return Response(event_stream(), mimetype="text/event-stream")
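
# Example client call (assumes the server is running locally on port 7860, as
# configured in __main__ below). curl's -N flag disables buffering so the SSE
# "data:" lines appear as they stream:
#   curl -N 'http://localhost:7860/api/v1/generate_stream?text=hola&temp=0.7&top_k=40&top_p=0.0&reppenalty=1.2'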
@app.route("/api/v1/generate", methods=["POST"])
def generate():
data = request.get_json()
text = data.get("text", "")
temp = float(data.get("temp", 0.7))
top_k = int(data.get("top_k", 40))
top_p = float(data.get("top_p", 0.0))
reppenalty = float(data.get("reppenalty", 1.2))
result = ""
try:
for token in perform_reasoning_stream(text, temperature=temp, top_k=top_k, top_p=top_p, repetition_penalty=reppenalty):
if token == "<END_STREAM>":
break
result += token + " "
except Exception as e:
return jsonify({"error": str(e)}), 500
return jsonify({"solidity": result.strip()})
@app.route("/api/v1/feedback", methods=["POST"])
def feedback():
data = request.get_json()
feedback_text = data.get("feedback_text")
correct_category = data.get("correct_category")
if feedback_text and correct_category:
feedback_queue.put((feedback_text, correct_category))
return jsonify({"status": "feedback received"})
return jsonify({"status": "feedback failed"}), 400
@app.route("/api/v1/tts", methods=["POST"])
def tts_api():
data = request.get_json()
text = data.get("text", "")
sr = 22050
duration = 3.0
t = torch.linspace(0, duration, int(sr * duration))
frequency = 440.0
audio = 0.5 * torch.sin(2 * torch.pi * frequency * t)
audio = audio.unsqueeze(0)
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
torchaudio.save(tmp.name, audio, sr)
tmp_path = tmp.name
return send_file(tmp_path, mimetype="audio/wav", as_attachment=True, download_name="output.wav")
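
# Example: download the placeholder tone as a WAV file (local-server assumption):
#   curl -X POST http://localhost:7860/api/v1/tts \
#        -H 'Content-Type: application/json' \
#        -d '{"text": "hello"}' --output output.wav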
@app.route("/api/v1/stt", methods=["POST"])
def stt_api():
data = request.get_json()
audio_b64 = data.get("audio", "")
if audio_b64:
audio_bytes = base64.b64decode(audio_b64)
buf = io.BytesIO(audio_bytes)
waveform, sr = torchaudio.load(buf)
mean_amp = waveform.abs().mean().item()
recognized_text = f"Audio processed with mean amplitude {mean_amp:.3f}"
return jsonify({"text": recognized_text})
return jsonify({"text": ""})
@app.route("/api/v1/sentiment", methods=["POST"])
def sentiment_api():
data = request.get_json()
text = data.get("text", "")
if not text:
return jsonify({"sentiment": "neutral"})
ascii_vals = [ord(c) for c in text[:10]]
while len(ascii_vals) < 10:
ascii_vals.append(0)
features = torch.tensor(ascii_vals, dtype=torch.float32).unsqueeze(0)
output = sentiment_model(features)
sentiment_idx = torch.argmax(output, dim=1).item()
sentiment = "positivo" if sentiment_idx == 1 else "negativo"
return jsonify({"sentiment": sentiment})
@app.route("/api/v1/imagegen", methods=["POST"])
def imagegen_api():
data = request.get_json()
prompt = data.get("prompt", "")
image_tensor = torch.rand(3, 256, 256)
np_image = image_tensor.mul(255).clamp(0, 255).byte().numpy().transpose(1, 2, 0)
img = Image.fromarray(np_image)
buf = io.BytesIO()
img.save(buf, format="PNG")
buf.seek(0)
return send_file(buf, mimetype="image/png", as_attachment=True, download_name="image.png")
@app.route("/api/v1/musicgen", methods=["POST"])
def musicgen_api():
data = request.get_json()
prompt = data.get("prompt", "")
sr = 22050
duration = 5.0
t = torch.linspace(0, duration, int(sr * duration))
frequency = 440.0
audio = 0.5 * torch.sin(2 * torch.pi * frequency * t)
audio = audio.unsqueeze(0)
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
torchaudio.save(tmp.name, tmp.name, sr)
tmp_path = tmp.name
return send_file(tmp_path, mimetype="audio/wav", as_attachment=True, download_name="music.wav")
@app.route("/api/v1/translation", methods=["POST"])
def translation_api():
data = request.get_json()
text = data.get("text", "")
translated = " ".join(text.split()[::-1])
return jsonify({"translated_text": translated})
@app.route("/api/v1/codegen", methods=["POST"])
def codegen_api():
data = request.get_json()
prompt = data.get("prompt", "")
generated_code = f"# Generated code based on prompt: {prompt}\nprint('Hello from Torch-generated code')"
return jsonify({"code": generated_code})
@app.route("/api/v1/text_to_video", methods=["POST"])
def text_to_video_api():
data = request.get_json()
prompt = data.get("prompt", "")
video_tensor = torch.randint(0, 255, (10, 3, 64, 64), dtype=torch.uint8)
video_bytes = video_tensor.numpy().tobytes()
buf = io.BytesIO(video_bytes)
return send_file(buf, mimetype="video/mp4", as_attachment=True, download_name="video.mp4")
@app.route("/api/v1/summarization", methods=["POST"])
def summarization_api():
data = request.get_json()
text = data.get("text", "")
sentences = text.split('.')
summary = sentences[0] if sentences[0] else text
return jsonify({"summary": summary})
@app.route("/api/v1/image_to_3d", methods=["POST"])
def image_to_3d_api():
data = request.get_json()
prompt = data.get("prompt", "")
obj_data = "o Cube\nv 0 0 0\nv 1 0 0\nv 1 1 0\nv 0 1 0\nf 1 2 3 4"
buf = io.BytesIO(obj_data.encode("utf-8"))
return send_file(buf, mimetype="text/plain", as_attachment=True, download_name="model.obj")
@app.route("/api/v1/sadtalker", methods=["GET"])
def sadtalker():
return jsonify({"message": "Respuesta de sadtalker"})
if __name__ == "__main__":
    app.run(host="0.0.0.0", port=7860)