# Hhhh / summarization_api.py
# Uploaded by Hjgugugjhuhjggg ("Upload 27 files"), commit 1c817fd, 1.29 kB.
from flask import jsonify, send_file, request
from main import *
#from main import import summarization_model, summarization_word_to_index, device
import torch
def summarize_text(text, output_path="output_summary.txt"):
    """Summarize ``text`` with the global summarization model and write it to a file.

    The input is tokenized by splitting on whitespace. Each lower-cased word is
    looked up in ``summarization_word_to_index``. Words missing from that
    vocabulary map to index 1 (presumably an ``<unk>`` id; confirm against the
    vocabulary built in ``main``).

    The ids are passed to ``summarization_model.generate`` with
    ``num_beams=4, max_length=100, early_stopping=True``. The first generated
    sequence is decoded with ``summarization_model.tokenizer``.

    ``text`` should contain at least one non-whitespace word. Otherwise an
    empty token sequence is handed to ``generate`` unchecked.

    Args:
        text: The text to summarize.
        output_path: File the decoded summary is written to, as UTF-8. An
            existing file is overwritten. The default is relative to the
            current working directory.

    Returns:
        ``output_path`` on success. If ``summarization_model`` is ``None``,
        returns the string ``"Summarization model not initialized."``.
        ``summarization_api`` tells the two apart by comparing against that
        exact string.
    """
    # summarization_model, summarization_word_to_index and device are provided
    # by ``from main import *``.
    if summarization_model is None:
        return "Summarization model not initialized."

    input_tokens = [
        summarization_word_to_index.get(word.lower(), 1) for word in text.split()
    ]
    # Batch of one sequence, moved to the model's device.
    input_tensor = torch.tensor([input_tokens], dtype=torch.long).to(device)

    # Inference only: no gradients are needed.
    with torch.no_grad():
        summary_ids = summarization_model.generate(
            input_tensor, num_beams=4, max_length=100, early_stopping=True
        )

    # NOTE(review): input ids come from summarization_word_to_index, but the
    # output is decoded with summarization_model.tokenizer. These must share a
    # vocabulary for the summary to be meaningful. Confirm in main.
    summary_text = summarization_model.tokenizer.decode(
        summary_ids[0], skip_special_tokens=True
    )

    # Write UTF-8 explicitly. The platform default encoding (e.g. cp1252 on
    # Windows) can fail on, or mis-encode, non-ASCII summaries. The file is
    # later served as text/plain, where UTF-8 is the expected charset.
    with open(output_path, "w", encoding="utf-8") as file:
        file.write(summary_text)
    return output_path
def summarization_api():
    """Flask view: summarize the ``text`` field of a JSON request body.

    Expects a JSON object such as ``{"text": "..."}``. On success, the summary
    is returned as a plain-text attachment named ``output_summary.txt``.

    Responses:
        - 400 ``{"error": "Request body must be a JSON object"}``: the body is
          missing, is not valid JSON, or is not a JSON object.
        - 400 ``{"error": "Text is required"}``: ``text`` is missing, is not a
          string, or contains only whitespace.
        - 500 ``{"error": "Summarization failed"}``: ``summarize_text``
          reports that the model is not initialized.
        - Otherwise: the summary file, sent with ``send_file``.
    """
    # silent=True makes get_json return None instead of raising on a missing
    # or invalid JSON body, so the client still gets this endpoint's JSON
    # error format.
    data = request.get_json(silent=True)
    # A JSON body of null, a list or a scalar has no .get(); reject it here
    # instead of failing with an AttributeError (HTTP 500).
    if not isinstance(data, dict):
        return jsonify({"error": "Request body must be a JSON object"}), 400

    text = data.get('text')
    # Reject non-strings: summarize_text calls text.split(). Also reject
    # whitespace-only strings, which would send an empty token sequence to
    # the model.
    if not isinstance(text, str) or not text.strip():
        return jsonify({"error": "Text is required"}), 400

    # NOTE(review): every request writes to the same default path
    # ("output_summary.txt"). Concurrent requests can therefore overwrite
    # each other's summary before it is sent. Consider a per-request
    # temporary file.
    output_file = summarize_text(text)
    # summarize_text signals an uninitialized model with this exact string.
    if output_file == "Summarization model not initialized.":
        return jsonify({"error": "Summarization failed"}), 500
    return send_file(output_file, mimetype="text/plain", as_attachment=True, download_name="output_summary.txt")