yakine committed
Commit 52cec89 · verified · 1 Parent(s): d61c33a

Update app.py

Files changed (1):
  1. app.py (+11 -5)
app.py CHANGED
@@ -75,14 +75,16 @@ generation_params = {
     "use_cache": False
 }

+# Generate synthetic data
 def generate_synthetic_data(description, columns):
     try:
-        # Prepare the input for the Llama model
         formatted_prompt = format_prompt(description, columns)

         # Tokenize the prompt with truncation enabled
-        inputs = tokenizer_llama(formatted_prompt, return_tensors="pt", truncation=True, max_length=512).to(model_llama.device)
-        print(f"Input Tensor Size: {inputs['input_ids'].size()}")
+        inputs = tokenizer_llama(formatted_prompt, return_tensors="pt", truncation=True, max_length=512)
+
+        # Move inputs to the correct device
+        inputs = {k: v.to(model_llama.device) for k, v in inputs.items()}

         # Generate synthetic data
         with torch.no_grad():
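The first hunk drops the debug print and splits device placement out of the tokenizer call: the prompt is tokenized first, then every tensor in the encoding is moved to `model_llama.device` with a dict comprehension. Both styles work with a Hugging Face `BatchEncoding` (which has its own `.to(device)` method); the explicit comprehension just makes the transfer visible per tensor. A minimal sketch of the pattern, using a small hypothetical checkpoint in place of the app's Llama model and tokenizer:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Hypothetical small checkpoint for illustration; app.py loads its own Llama model.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

# Tokenize first, without chaining .to() onto the call.
inputs = tokenizer("Example prompt", return_tensors="pt", truncation=True, max_length=512)

# Style before this commit: BatchEncoding.to() moves all tensors at once.
inputs_chained = inputs.to(model.device)

# Style after this commit: move each tensor to the model's device explicitly.
inputs_explicit = {k: v.to(model.device) for k, v in inputs.items()}

with torch.no_grad():
    output_ids = model.generate(**inputs_explicit, max_new_tokens=20)

print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```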
@@ -94,15 +96,19 @@ def generate_synthetic_data(description, columns):
             num_return_sequences=1,
         )

+        # Check for meta tensor before decoding
+        if outputs.is_meta:
+            raise ValueError("Output tensor is in meta state, check model and input.")
+
         # Decode the generated output
         generated_text = tokenizer_llama.decode(outputs[0], skip_special_tokens=True)
-
+
         # Return the generated synthetic data
         return generated_text
     except Exception as e:
-        st.error(f"Error in generate_synthetic_data: {e}")
         return f"Error: {e}"

+
 def generate_large_synthetic_data(description, columns, num_rows=1000, rows_per_generation=100):
     data_frames = []
     num_iterations = num_rows // rows_per_generation
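The second hunk removes the `st.error` call from the exception handler (the error string is still returned to the caller) and adds a guard on `outputs.is_meta` before decoding. `is_meta` is a standard `torch.Tensor` attribute, and `model.generate` returns a plain tensor of token ids by default, so the check applies: a meta tensor carries only shape and dtype, no data, which usually means the model's weights were never materialized onto a real device. A small illustration of what the guard detects:

```python
import torch

# A meta tensor has shape and dtype but no storage behind it.
meta_ids = torch.empty(2, 3, dtype=torch.long, device="meta")
print(meta_ids.is_meta)   # True

# A normal tensor is not meta, so the new guard in app.py lets it through.
real_ids = torch.zeros(2, 3, dtype=torch.long)
print(real_ids.is_meta)   # False

# Reading values out of a meta tensor fails, which is why decoding token ids
# from one can never succeed; raising early gives a clearer error message.
try:
    meta_ids.tolist()
except Exception as exc:
    print(f"Cannot read a meta tensor: {exc}")
```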
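The diff ends at the top of `generate_large_synthetic_data`, which prepares an empty list of data frames and computes `num_rows // rows_per_generation` iterations; its body is outside this change. A hedged sketch of how such a chunked loop could look, assuming the model is prompted to emit CSV-like rows (the parsing step is an assumption, not the app's actual code):

```python
import pandas as pd
from io import StringIO

def generate_large_synthetic_data(description, columns, num_rows=1000, rows_per_generation=100):
    # Sketch only: the real implementation in app.py is not shown in this diff.
    data_frames = []
    num_iterations = num_rows // rows_per_generation
    for _ in range(num_iterations):
        chunk_text = generate_synthetic_data(description, columns)
        # Assumption: the prompt asks the model for comma-separated rows,
        # so each chunk can be parsed into a DataFrame with the given columns.
        chunk_df = pd.read_csv(StringIO(chunk_text), names=columns, header=None)
        data_frames.append(chunk_df)
    # Combine all chunks into a single DataFrame.
    return pd.concat(data_frames, ignore_index=True)
```

A real implementation would also need to handle the remainder when `num_rows` is not a multiple of `rows_per_generation` and cope with chunks that fail to parse.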