kyserS09 committed on
Commit 3d097bd · verified · 1 Parent(s): bbb8ec0

Update app.py

Files changed (1)
  1. app.py +111 -20
app.py CHANGED
@@ -1,34 +1,125 @@
 import gradio as gr
-from transformers import LEDTokenizer, LEDForConditionalGeneration

-# Use Longformer Encoder-Decoder (LED) model
-model_name = "allenai/led-large-16384"
-tokenizer = LEDTokenizer.from_pretrained(model_name)
-model = LEDForConditionalGeneration.from_pretrained(model_name)

-def summarize_text(text):
-    # Tokenize input with truncation to fit within 16,384 tokens
-    inputs = tokenizer([text], max_length=16384, return_tensors="pt", truncation=True)

-    # Generate summary with adjusted parameters
-    summary_ids = model.generate(
-        inputs["input_ids"],
-        num_beams=4,
-        max_length=512,  # Can be adjusted based on summary size needs
-        min_length=100,
-        early_stopping=True
-    )

-    summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
-    return summary

 # Gradio Interface
 iface = gr.Interface(
     fn=summarize_text,
     inputs="text",
     outputs="text",
-    title="Longformer Summarizer",
-    description="Enter text to get a summary using the Longformer Encoder-Decoder."
 )

 if __name__ == "__main__":
 
 import gradio as gr
+import os
+import requests
+import torch
+from transformers import (
+    LEDTokenizer, LEDForConditionalGeneration,
+    BartTokenizer, BartForConditionalGeneration,
+    PegasusTokenizer, PegasusForConditionalGeneration,
+    AutoTokenizer, AutoModelForSeq2SeqLM
+)

+# OpenAI API key
+OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")  # Ensure this is set in your environment variables

+# List of models in priority order
+MODELS = [
+    {
+        "name": "allenai/led-large-16384",
+        "tokenizer_class": LEDTokenizer,
+        "model_class": LEDForConditionalGeneration
+    },
+    {
+        "name": "facebook/bart-large-cnn",
+        "tokenizer_class": BartTokenizer,
+        "model_class": BartForConditionalGeneration
+    },
+    {
+        "name": "Falconsai/text_summarization",
+        "tokenizer_class": AutoTokenizer,
+        "model_class": AutoModelForSeq2SeqLM
+    },
+    {
+        "name": "google/pegasus-xsum",
+        "tokenizer_class": PegasusTokenizer,
+        "model_class": PegasusForConditionalGeneration
+    }
+]
+
+# Load models sequentially; failures are logged and skipped
+loaded_models = []
+for model_info in MODELS:
+    try:
+        tokenizer = model_info["tokenizer_class"].from_pretrained(model_info["name"])
+        model = model_info["model_class"].from_pretrained(model_info["name"])
+        loaded_models.append({"name": model_info["name"], "tokenizer": tokenizer, "model": model})
+        print(f"Loaded model: {model_info['name']}")
+    except Exception as e:
+        print(f"Failed to load {model_info['name']}: {e}")
+
+def summarize_with_transformers(text):
+    """
+    Try summarizing with locally loaded Transformer models in order of priority.
+    """
+    for model_data in loaded_models:
+        try:
+            tokenizer = model_data["tokenizer"]
+            model = model_data["model"]
+
+            # Tokenize with truncation, capped at each model's own context
+            # window (16,384 for LED, 1,024 for BART, 512 for Pegasus)
+            inputs = tokenizer([text], max_length=tokenizer.model_max_length, return_tensors="pt", truncation=True)
+
+            # Generate summary
+            summary_ids = model.generate(
+                inputs["input_ids"],
+                num_beams=4,
+                max_length=512,
+                min_length=100,
+                early_stopping=True
+            )

+            summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
+            return summary  # Return the first successful response

+        except Exception as e:
+            print(f"Error using {model_data['name']}: {e}")
+
+    return None  # Indicate failure
+
+def summarize_with_chatgpt(text):
+    """
+    Fall back to the OpenAI ChatGPT API if all other models fail.
+    """
+    if not OPENAI_API_KEY:
+        return "Error: No OpenAI API key provided."
+
+    headers = {
+        "Authorization": f"Bearer {OPENAI_API_KEY}",
+        "Content-Type": "application/json"
+    }
+
+    payload = {
+        "model": "gpt-3.5-turbo",
+        "messages": [{"role": "user", "content": f"Summarize this article: {text}"}],
+        "max_tokens": 512
+    }
+
+    response = requests.post("https://api.openai.com/v1/chat/completions", headers=headers, json=payload)
+
+    if response.status_code == 200:
+        return response.json()["choices"][0]["message"]["content"]
+    else:
+        return f"Error: Failed to summarize with ChatGPT (status {response.status_code})"
+
+def summarize_text(text):
+    """
+    Main entry point: try the Transformer models first, then ChatGPT if needed.
+    """
+    summary = summarize_with_transformers(text)
+
+    if summary:
+        return summary  # Return successful summary from a Transformer model
+
+    print("All Transformer models failed. Falling back to ChatGPT...")
+    return summarize_with_chatgpt(text)  # Use ChatGPT as a last resort

 # Gradio Interface
 iface = gr.Interface(
     fn=summarize_text,
     inputs="text",
     outputs="text",
+    title="Multi-Model Summarizer with Fallback",
+    description="Tries multiple models for summarization, falling back to ChatGPT if needed."
 )

 if __name__ == "__main__":
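
For reviewers who want to exercise the new fallback chain locally, a minimal smoke test might look like the sketch below. It is illustrative only and not part of this commit: it assumes the updated file is saved as app.py on the Python path, the sample text is made up, importing app triggers the download of every checkpoint that loads (allenai/led-large-16384 alone is several gigabytes), and iface.launch() is the standard Gradio call for serving an interface (the hunk above ends at the __main__ guard, so the original launch line is not shown).

# smoke_test.py -- illustrative sketch, assumes the committed file is saved as app.py
from app import summarize_text, iface  # importing app loads the models at import time

SAMPLE = (
    "The Longformer Encoder-Decoder extends BART with sparse attention, "
    "letting it process documents of up to 16,384 tokens in a single pass."
)

# Walks the whole chain: loaded Transformer models in priority order,
# then the ChatGPT fallback if every local model raises.
print(summarize_text(SAMPLE))

# Serve the Gradio UI for interactive testing.
iface.launch()

Note that with the committed generate() settings (min_length=100), very short inputs like the sample above may produce padded or repetitive summaries; that behavior is inherited from the diff, not introduced by the test.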