yakine committed on
Commit 16360f1 · verified · 1 Parent(s): 6bc7f9d

Update app.py

Files changed (1)
  1. app.py +112 -119
app.py CHANGED
@@ -1,143 +1,136 @@
  from fastapi import FastAPI, HTTPException
- from fastapi.responses import StreamingResponse, JSONResponse
  from pydantic import BaseModel
- import pandas as pd
- import os
- import requests
- from transformers import GPT2LMHeadModel, GPT2Tokenizer, AutoTokenizer, pipeline
- from io import StringIO
- from fastapi.middleware.cors import CORSMiddleware
- from huggingface_hub import HfFolder
- from tqdm import tqdm

  app = FastAPI()

  app.add_middleware(
      CORSMiddleware,
-     allow_origins=["*"],
      allow_credentials=True,
      allow_methods=["*"],
      allow_headers=["*"],
  )

- hf_token = os.getenv('HF_API_TOKEN')
- if not hf_token:
-     raise ValueError("Hugging Face API token is not set. Please set the HF_API_TOKEN environment variable.")
-
- # Load GPT-2 model and tokenizer
- tokenizer_gpt2 = AutoTokenizer.from_pretrained('gpt2')
- model_gpt2 = GPT2LMHeadModel.from_pretrained('gpt2')
-
- # Create a pipeline for text generation using GPT-2
- text_generator = pipeline("text-generation", model=model_gpt2, tokenizer=tokenizer_gpt2)
-
- # Define prompt template for generating the dataset
- prompt_template = """\
- You are an AI designed exclusively for generating synthetic tabular datasets.
- Task: Generate a synthetic dataset based on the provided description and column names.
- Description: {description}
- Columns: {columns}
- Instructions:
- - Output only the tabular data in valid CSV format.
- - Include the header row followed strictly by the data rows.
- - Do not include any additional text, explanations, comments, or code outside of the CSV data.
- - Ensure that the values for each column are contextually appropriate based on the description.
- - Do not alter the column names or add any new columns.
- - Each row must contain data for all columns without any empty values unless specified.
- - Format Example (do not include this line or the following example in your output):
- Column1,Column2,Column3
- Value1,Value2,Value3
- Value4,Value5,Value6
  """

- # Define generation parameters
- generation_params = {
-     "top_p": 0.90,
-     "temperature": 0.8,
-     "max_new_tokens": 2048,
-     "return_full_text": False,
-     "use_cache": False
- }
-
- def format_prompt(description, columns):
-     prompt = prompt_template.format(description=description, columns=",".join(columns))
-     return prompt
-
- def generate_synthetic_data(description, columns):
-     formatted_prompt = format_prompt(description, columns)
-     payload = {"inputs": formatted_prompt, "parameters": generation_params}
-     # Call Mixtral model to generate data
-     response = requests.post("https://api-inference.huggingface.co/models/mistralai/Mixtral-8x7B-Instruct-v0.1",
-                              headers={"Authorization": f"Bearer {hf_token}"}, json=payload)
-     if response.status_code == 200:
-         return response.json()[0]["generated_text"]
-     else:
-         print(f"Error generating data: {response.status_code}, {response.text}")
-         return None
-
- def process_generated_data(csv_data):
      try:
-         # Ensure the data is cleaned and correctly formatted
-         cleaned_data = csv_data.replace('\r\n', '\n').replace('\r', '\n')
-         data = StringIO(cleaned_data)
-         # Read the CSV data with specific parameters to handle irregularities
-         df = pd.read_csv(data)
-         return df
-     except pd.errors.ParserError as e:
-         print(f"Failed to parse CSV data: {e}")
-         return None
-
- def generate_large_synthetic_data(description, columns, num_rows=1000, rows_per_generation=100):
-     data_frames = []
-     for _ in tqdm(range(num_rows // rows_per_generation), desc="Generating Data"):
-         generated_data = generate_synthetic_data(description, columns)
-         if generated_data:
-             df_synthetic = process_generated_data(generated_data)
-             if df_synthetic is not None and not df_synthetic.empty:
-                 data_frames.append(df_synthetic)
-             else:
-                 print("Skipping invalid generation.")
-     if data_frames:
-         return pd.concat(data_frames, ignore_index=True)
-     else:
-         print("No valid data frames to concatenate.")
-         return pd.DataFrame(columns=columns)
-
- class DataGenerationRequest(BaseModel):
-     description: str
-     columns: list[str]
-
- @app.post("/generate/")
- def generate_data(request: DataGenerationRequest):
-     description = request.description.strip()
-     columns = [col.strip() for col in request.columns]
-     csv_data = generate_large_synthetic_data(description, columns, num_rows=1000, rows_per_generation=100)
-     if csv_data.empty:
-         return JSONResponse(content={"error": "No valid data generated"}, status_code=500)
-     csv_buffer = StringIO()
-     csv_data.to_csv(csv_buffer, index=False)
-     csv_buffer.seek(0)
-     return StreamingResponse(
-         csv_buffer,
-         media_type="text/csv",
-         headers={"Content-Disposition": "attachment; filename=generated_data.csv"}
-     )

  @app.get("/")
  def greet_json():
-     return {"Hello": "World!"}
 
  from fastapi import FastAPI, HTTPException
  from pydantic import BaseModel
+ import torch
+ from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForSeq2SeqLM
+ import logging
+ import re

  app = FastAPI()

+ # Enable CORS if needed
+ from fastapi.middleware.cors import CORSMiddleware
  app.add_middleware(
      CORSMiddleware,
+     allow_origins=["*"],  # In production, restrict this to your frontend URL
      allow_credentials=True,
      allow_methods=["*"],
      allow_headers=["*"],
  )

+ logger = logging.getLogger(__name__)
+ logging.basicConfig(level=logging.INFO)
+
+ ####################################
+ # Text Generation Endpoint
+ ####################################
+
+ TEXT_MODEL_NAME = "aubmindlab/aragpt2-medium"
+ text_tokenizer = AutoTokenizer.from_pretrained(TEXT_MODEL_NAME)
+ text_model = AutoModelForCausalLM.from_pretrained(TEXT_MODEL_NAME)
+
+ general_prompt_template = """
+ أنت الآن نموذج لغة مخصص لتوليد نصوص عربية تعليمية بناءً على المادة والمستوى التعليمي. سيتم إعطاؤك مادة تعليمية ومستوى تعليمي، وعليك إنشاء نص مناسب بناءً على ذلك. النص يجب أن يكون:
+ 1. واضحًا وسهل الفهم.
+ 2. مناسبًا للمستوى التعليمي المحدد.
+ 3. مرتبطًا بالمادة التعليمية المطلوبة.
+ 4. قصيرًا ومباشرًا.
+ ### أمثلة:
+ 1. المادة: العلوم
+ المستوى: الابتدائي
+ النص: النباتات كائنات حية تحتاج إلى الماء والهواء وضوء الشمس لتنمو. بعض النباتات تنتج أزهارًا وفواكه. النباتات تساعدنا في الحصول على الأكسجين.
+ 2. المادة: التاريخ
+ المستوى: المتوسط
+ النص: التاريخ هو دراسة الماضي وأحداثه المهمة. من خلال التاريخ، نتعلم عن الحضارات القديمة مثل الحضارة الفرعونية والحضارة الإسلامية. التاريخ يساعدنا على فهم تطور البشرية.
+ 3. المادة: الجغرافيا
+ المستوى: المتوسط
+ النص: الجغرافيا هي دراسة الأرض وخصائصها. نتعلم عن القارات والمحيطات والجبال. الجغرافيا تساعدنا على فهم العالم من حولنا.
+ ---
+ المادة: {المادة}
+ المستوى: {المستوى}
+ اكتب نصًا مناسبًا بناءً على المادة والمستوى المحددين. ركّز على جعل النص بسيطًا وواضحًا للمستوى المطلوب.
  """

+ class GenerateTextRequest(BaseModel):
+     المادة: str
+     المستوى: str

+ @app.post("/generate-text")
+ def generate_text(request: GenerateTextRequest):
+     if not request.المادة or not request.المستوى:
+         raise HTTPException(status_code=400, detail="المادة والمستوى مطلوبان.")
      try:
+         prompt = general_prompt_template.format(المادة=request.المادة, المستوى=request.المستوى)
+         inputs = text_tokenizer(prompt, return_tensors="pt", max_length=1024, truncation=True)
+         with torch.no_grad():
+             outputs = text_model.generate(
+                 inputs.input_ids,
+                 max_length=300,
+                 num_return_sequences=1,
+                 temperature=0.7,
+                 top_p=0.95,
+                 do_sample=True,
+             )
+         generated_text = text_tokenizer.decode(outputs[0], skip_special_tokens=True).replace(prompt, "").strip()
+         logger.info(f"Generated text: {generated_text}")
+         return {"generated_text": generated_text}
+     except Exception as e:
+         logger.error(f"Error during text generation: {str(e)}")
+         raise HTTPException(status_code=500, detail=f"Error during text generation: {str(e)}")
+
+ ####################################
+ # Question & Answer Generation Model
+ ####################################
+ QA_MODEL_NAME = "Mihakram/AraT5-base-question-generation"
+ qa_tokenizer = AutoTokenizer.from_pretrained(QA_MODEL_NAME)
+ qa_model = AutoModelForSeq2SeqLM.from_pretrained(QA_MODEL_NAME)
+
+ def extract_answer(context: str) -> str:
+     """Extract the first sentence (or a key phrase) from the context."""
+     sentences = re.split(r'[.!؟]', context)
+     sentences = [s.strip() for s in sentences if s.strip()]
+     return sentences[0] if sentences else context
+
+ def get_question(context: str, answer: str) -> str:
+     """Generate a question based on the context and the candidate answer."""
+     text = f"النص: {context} الإجابة: {answer} </s>"
+     text_encoding = qa_tokenizer.encode_plus(text, return_tensors="pt")
+     with torch.no_grad():
+         generated_ids = qa_model.generate(
+             input_ids=text_encoding['input_ids'],
+             attention_mask=text_encoding['attention_mask'],
+             max_length=64,
+             num_beams=5,
+             num_return_sequences=1
+         )
+     question = qa_tokenizer.decode(generated_ids[0], skip_special_tokens=True).replace("question:", "").strip()
+     return question

+ class GenerateQARequest(BaseModel):
+     text: str
+
+ @app.post("/generate-qa")
+ def generate_qa(request: GenerateQARequest):
+     if not request.text:
+         raise HTTPException(status_code=400, detail="Text is required.")
+
+     try:
+         answer = extract_answer(request.text)
+         question = get_question(request.text, answer)
+         logger.info(f"Generated QA -> Question: {question}, Answer: {answer}")
+         return {"question": question, "answer": answer}
+
+     except Exception as e:
+         logger.error(f"Error during QA generation: {str(e)}")
+         raise HTTPException(status_code=500, detail=f"Error during QA generation: {str(e)}")

+ ####################################
+ # Root Endpoint
+ ####################################
  @app.get("/")
  def greet_json():
+     return {"message": "Welcome to the Arabic Text Generation API!"}
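
For reference, a minimal client sketch against the two new endpoints, assuming the app is served locally with `uvicorn app:app --port 8000`; the base URL and the example payloads are illustrative assumptions, not part of the commit:

import requests

BASE_URL = "http://localhost:8000"  # assumed local deployment

# /generate-text expects the Arabic field names defined on GenerateTextRequest.
resp = requests.post(
    f"{BASE_URL}/generate-text",
    json={"المادة": "العلوم", "المستوى": "الابتدائي"},
)
resp.raise_for_status()
print(resp.json()["generated_text"])

# /generate-qa extracts a candidate answer from the passage and generates a matching question.
resp = requests.post(
    f"{BASE_URL}/generate-qa",
    json={"text": "النباتات كائنات حية تحتاج إلى الماء والهواء وضوء الشمس لتنمو."},
)
resp.raise_for_status()
print(resp.json())  # {"question": ..., "answer": ...}

Validation failures come back as HTTP 400 and generation errors as HTTP 500 with a detail message, so raise_for_status() surfaces both.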