yakine committed
Commit d61c33a · verified · 1 Parent(s): 6dbe5c9

Update app.py

Files changed (1): app.py (+15, -17)
app.py CHANGED
@@ -30,12 +30,11 @@ text_generator = pipeline("text-generation", model=model_gpt2, tokenizer=tokeniz
 # Load the Llama-3 model and tokenizer once during startup
 tokenizer_llama = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3.1-8B", token=hf_token)
 model_llama = AutoModelForCausalLM.from_pretrained(
-    "meta-llama/Meta-Llama-3.1-8B",
-    torch_dtype='auto',
-    device_map='auto',
-    token=hf_token
-)
-
+    "meta-llama/Meta-Llama-3.1-8B",
+    torch_dtype='auto',
+    device_map='auto',
+    token=hf_token
+)
 
 # Define your prompt template
 prompt_template = """\
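For reference, a self-contained sketch of the load pattern this hunk touches; `hf_token` is a placeholder here (the real token is read elsewhere in app.py), and `device_map='auto'` assumes the `accelerate` package is installed:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

hf_token = "hf_..."  # placeholder; a real token with Llama access is assumed

tokenizer_llama = AutoTokenizer.from_pretrained(
    "meta-llama/Meta-Llama-3.1-8B", token=hf_token
)
model_llama = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3.1-8B",
    torch_dtype="auto",   # use the dtype stored in the checkpoint
    device_map="auto",    # needs `accelerate`; places weights on available devices
    token=hf_token,
)
```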
@@ -80,11 +79,11 @@ def generate_synthetic_data(description, columns):
     try:
         # Prepare the input for the Llama model
         formatted_prompt = format_prompt(description, columns)
-
+
         # Tokenize the prompt with truncation enabled
         inputs = tokenizer_llama(formatted_prompt, return_tensors="pt", truncation=True, max_length=512).to(model_llama.device)
         print(f"Input Tensor Size: {inputs['input_ids'].size()}")
-
+
         # Generate synthetic data
         with torch.no_grad():
             outputs = model_llama.generate(
@@ -94,10 +93,10 @@ def generate_synthetic_data(description, columns):
                 temperature=generation_params["temperature"],
                 num_return_sequences=1,
             )
-
+
         # Decode the generated output
         generated_text = tokenizer_llama.decode(outputs[0], skip_special_tokens=True)
-
+
         # Return the generated synthetic data
         return generated_text
     except Exception as e:
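A minimal, runnable version of the tokenize, generate, decode flow modified in the two hunks above, swapped to the ungated `gpt2` checkpoint so it runs without a token; the sampling settings are illustrative stand-ins for the app's `generation_params`:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

prompt = "Generate CSV rows with columns name, age, city:\n"
# Truncate overly long prompts so they fit the model's context window
inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512).to(model.device)

with torch.no_grad():  # inference only; no gradient bookkeeping
    outputs = model.generate(
        **inputs,
        max_new_tokens=64,
        temperature=0.7,
        do_sample=True,  # temperature only takes effect when sampling
        num_return_sequences=1,
        pad_token_id=tokenizer.eos_token_id,  # gpt2 defines no pad token
    )

print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```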
@@ -110,20 +109,20 @@ def generate_large_synthetic_data(description, columns, num_rows=1000, rows_per_
 
     # Create a progress bar
     progress_bar = st.progress(0)
-
+
     for i in tqdm(range(num_iterations)):
         generated_data = generate_synthetic_data(description, columns)
-
-        print("Generated Data:\n", generated_data)  # Add this line to debug
+        print("Generated Data:\n", generated_data)  # Move the print statement here
 
         if "Error" in generated_data:
             return generated_data
+
        df_synthetic = process_generated_data(generated_data)
        data_frames.append(df_synthetic)
-
+
        # Update the progress bar
        progress_bar.progress((i + 1) / num_iterations)
-
+
    return pd.concat(data_frames, ignore_index=True)
 
 def process_generated_data(csv_data):
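One small note on the loop above: `st.progress` accepts either an int from 0 to 100 or a float from 0.0 to 1.0, which is why the fractional `(i + 1) / num_iterations` update works. A stripped-down illustration, with `time.sleep` standing in for a generation call:

```python
import time
import streamlit as st

num_iterations = 10
progress_bar = st.progress(0)  # starts empty
for i in range(num_iterations):
    time.sleep(0.1)  # stand-in for generate_synthetic_data(...)
    progress_bar.progress((i + 1) / num_iterations)  # float in [0.0, 1.0]
```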
@@ -137,7 +136,6 @@ def process_generated_data(csv_data):
     print("DataFrame Shape:", df.shape)
     print("DataFrame Head:\n", df.head())
 
-
     # Check if the DataFrame is empty
     if df.empty:
         raise ValueError("Generated DataFrame is empty.")
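Most of `process_generated_data` falls outside this hunk. For orientation, a plausible minimal implementation of the step it performs (raw model text to DataFrame), assuming the model emits comma-separated rows with a header line; this is a sketch, not the file's actual code:

```python
import io
import pandas as pd

def process_generated_data(csv_data: str) -> pd.DataFrame:
    # Parse the model's raw text as CSV, skipping malformed rows
    df = pd.read_csv(io.StringIO(csv_data), on_bad_lines="skip")

    print("DataFrame Shape:", df.shape)
    print("DataFrame Head:\n", df.head())

    # Check if the DataFrame is empty
    if df.empty:
        raise ValueError("Generated DataFrame is empty.")
    return df
```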
@@ -157,7 +155,7 @@ if st.button("Generate"):
     description = description.strip()
     columns = [col.strip() for col in columns.split(',')]
     df_synthetic = generate_large_synthetic_data(description, columns)
-
+
     if isinstance(df_synthetic, str) and "Error" in df_synthetic:
         st.error(df_synthetic)  # Display error message if any
     else: