Update app.py
app.py CHANGED
@@ -30,12 +30,11 @@ text_generator = pipeline("text-generation", model=model_gpt2, tokenizer=tokeniz
 # Load the Llama-3 model and tokenizer once during startup
 tokenizer_llama = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3.1-8B", token=hf_token)
 model_llama = AutoModelForCausalLM.from_pretrained(
-
-
-
-
-
-
+    "meta-llama/Meta-Llama-3.1-8B",
+    torch_dtype='auto',
+    device_map='auto',
+    token=hf_token
+)
 
 # Define your prompt template
 prompt_template = """\
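For context, torch_dtype='auto' picks the dtype stored in the checkpoint config and device_map='auto' relies on the accelerate package being installed. Below is a minimal sketch of the updated loading step as it would run standalone; reading the token from an HF_TOKEN environment variable is an assumption, the app obtains hf_token elsewhere.

# Sketch only: the new loading call run standalone. Assumes accelerate is installed
# and HF_TOKEN is set; the app itself defines hf_token through other means.
import os
from transformers import AutoModelForCausalLM, AutoTokenizer

hf_token = os.environ.get("HF_TOKEN")  # assumed source of the access token
model_id = "meta-llama/Meta-Llama-3.1-8B"

tokenizer_llama = AutoTokenizer.from_pretrained(model_id, token=hf_token)
model_llama = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype="auto",   # use the dtype recorded in the checkpoint (bfloat16 for this model)
    device_map="auto",    # let accelerate spread the weights over available GPU/CPU memory
    token=hf_token,
)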
@@ -80,11 +79,11 @@ def generate_synthetic_data(description, columns):
     try:
         # Prepare the input for the Llama model
         formatted_prompt = format_prompt(description, columns)
-
+
         # Tokenize the prompt with truncation enabled
         inputs = tokenizer_llama(formatted_prompt, return_tensors="pt", truncation=True, max_length=512).to(model_llama.device)
         print(f"Input Tensor Size: {inputs['input_ids'].size()}")
-
+
         # Generate synthetic data
         with torch.no_grad():
             outputs = model_llama.generate(
@@ -94,10 +93,10 @@ def generate_synthetic_data(description, columns):
                 temperature=generation_params["temperature"],
                 num_return_sequences=1,
             )
-
+
         # Decode the generated output
         generated_text = tokenizer_llama.decode(outputs[0], skip_special_tokens=True)
-
+
         # Return the generated synthetic data
         return generated_text
     except Exception as e:
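Taken together with the hunk above, these changes span the whole tokenize, generate, decode path of generate_synthetic_data. A hedged sketch of that path follows; run_generation is a hypothetical wrapper name, and generation_params, max_new_tokens, and do_sample=True are assumptions, since the leading arguments of model_llama.generate() fall outside the diff.

# Sketch of the generation path touched by these hunks; the generation_params values,
# max_new_tokens, and do_sample=True are assumptions, not taken from app.py.
import torch

generation_params = {"max_new_tokens": 512, "temperature": 0.7}  # assumed values

def run_generation(formatted_prompt: str) -> str:
    # Truncate over-long prompts so the input never exceeds 512 tokens
    inputs = tokenizer_llama(
        formatted_prompt, return_tensors="pt", truncation=True, max_length=512
    ).to(model_llama.device)

    # Inference only: disabling gradient tracking saves memory
    with torch.no_grad():
        outputs = model_llama.generate(
            **inputs,
            max_new_tokens=generation_params["max_new_tokens"],
            do_sample=True,  # temperature only takes effect when sampling is enabled
            temperature=generation_params["temperature"],
            num_return_sequences=1,
        )

    # skip_special_tokens drops markers such as <|begin_of_text|> from the decoded text
    return tokenizer_llama.decode(outputs[0], skip_special_tokens=True)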
@@ -110,20 +109,20 @@ def generate_large_synthetic_data(description, columns, num_rows=1000, rows_per_
 
     # Create a progress bar
     progress_bar = st.progress(0)
-
+
     for i in tqdm(range(num_iterations)):
         generated_data = generate_synthetic_data(description, columns)
-
-        print("Generated Data:\n", generated_data) # Add this line to debug
+        print("Generated Data:\n", generated_data) # Move the print statement here
 
         if "Error" in generated_data:
             return generated_data
+
         df_synthetic = process_generated_data(generated_data)
         data_frames.append(df_synthetic)
-
+
         # Update the progress bar
         progress_bar.progress((i + 1) / num_iterations)
-
+
     return pd.concat(data_frames, ignore_index=True)
 
 def process_generated_data(csv_data):
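The loop above advances the progress bar one generation chunk at a time. How num_iterations is computed is not visible in the diff (the hunk header is truncated after rows_per_), so the parameter name and arithmetic in the sketch below are assumptions.

# Assumed wiring around the loop above; rows_per_generation is a guessed name for the
# truncated "rows_per_..." parameter in the function signature.
import math

def plan_iterations(num_rows=1000, rows_per_generation=50):
    # One model call per chunk; round up so a final partial chunk is still generated
    return math.ceil(num_rows / rows_per_generation)

num_iterations = plan_iterations()  # 20 chunks for the defaults above
# st.progress() accepts a float in [0.0, 1.0], hence the (i + 1) / num_iterations update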
@@ -137,7 +136,6 @@ def process_generated_data(csv_data):
     print("DataFrame Shape:", df.shape)
     print("DataFrame Head:\n", df.head())
 
-
     # Check if the DataFrame is empty
     if df.empty:
         raise ValueError("Generated DataFrame is empty.")
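Only the debug prints and the empty-DataFrame guard of process_generated_data appear in this hunk. One plausible shape for the whole helper is sketched below, assuming the generated text is parsed as CSV via io.StringIO, which the diff does not confirm.

# Hedged sketch of process_generated_data; the read_csv/StringIO parsing is an assumption.
import io
import pandas as pd

def process_generated_data(csv_data: str) -> pd.DataFrame:
    # Parse the CSV-formatted text returned by the model directly from memory
    df = pd.read_csv(io.StringIO(csv_data))

    print("DataFrame Shape:", df.shape)
    print("DataFrame Head:\n", df.head())

    # Check if the DataFrame is empty
    if df.empty:
        raise ValueError("Generated DataFrame is empty.")
    return df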
@@ -157,7 +155,7 @@ if st.button("Generate"):
     description = description.strip()
     columns = [col.strip() for col in columns.split(',')]
     df_synthetic = generate_large_synthetic_data(description, columns)
-
+
     if isinstance(df_synthetic, str) and "Error" in df_synthetic:
         st.error(df_synthetic) # Display error message if any
     else: