Update app.py
app.py CHANGED
```diff
@@ -75,14 +75,16 @@ generation_params = {
     "use_cache": False
 }
 
+# Generate synthetic data
 def generate_synthetic_data(description, columns):
     try:
-        # Prepare the input for the Llama model
         formatted_prompt = format_prompt(description, columns)
 
         # Tokenize the prompt with truncation enabled
-        inputs = tokenizer_llama(formatted_prompt, return_tensors="pt", truncation=True, max_length=512)
-
+        inputs = tokenizer_llama(formatted_prompt, return_tensors="pt", truncation=True, max_length=512)
+
+        # Move inputs to the correct device
+        inputs = {k: v.to(model_llama.device) for k, v in inputs.items()}
 
         # Generate synthetic data
         with torch.no_grad():
@@ -94,15 +96,19 @@ def generate_synthetic_data(description, columns):
             num_return_sequences=1,
         )
 
+        # Check for meta tensor before decoding
+        if outputs.is_meta:
+            raise ValueError("Output tensor is in meta state, check model and input.")
+
         # Decode the generated output
         generated_text = tokenizer_llama.decode(outputs[0], skip_special_tokens=True)
-
+
         # Return the generated synthetic data
         return generated_text
     except Exception as e:
-        st.error(f"Error in generate_synthetic_data: {e}")
         return f"Error: {e}"
 
+
 def generate_large_synthetic_data(description, columns, num_rows=1000, rows_per_generation=100):
     data_frames = []
     num_iterations = num_rows // rows_per_generation
```
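For context, here is a minimal, self-contained sketch of the pattern this commit adopts: tokenize with truncation, move the input tensors onto the model's device, generate, and guard against meta tensors before decoding. The model id, prompt, and generation settings below are illustrative placeholders, not the app's actual configuration.

```python
# A minimal sketch of the device-placement and meta-tensor guard pattern,
# NOT the app's full code. Assumptions: any causal LM stands in for the
# app's Llama model; model id, prompt, and generation settings are placeholders.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"  # placeholder for the app's Llama model
tokenizer_llama = AutoTokenizer.from_pretrained(model_id)
model_llama = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto")

def generate_synthetic_data_sketch(prompt: str) -> str:
    # Tokenize with truncation, then move every input tensor onto the device
    # the model weights were dispatched to (important with device_map="auto")
    inputs = tokenizer_llama(prompt, return_tensors="pt", truncation=True, max_length=512)
    inputs = {k: v.to(model_llama.device) for k, v in inputs.items()}

    # Generate without tracking gradients
    with torch.no_grad():
        outputs = model_llama.generate(**inputs, max_new_tokens=256, num_return_sequences=1)

    # Guard against meta tensors: a tensor on the "meta" device has no data to decode
    if outputs.is_meta:
        raise ValueError("Output tensor is in meta state, check model and input.")

    return tokenizer_llama.decode(outputs[0], skip_special_tokens=True)

print(generate_synthetic_data_sketch("Generate 3 rows of CSV with columns name, age, city."))
```

With `device_map="auto"`, weights may be dispatched across devices, so sending the tokenized inputs to `model_llama.device` avoids device-mismatch errors during `generate`.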