Files changed (4)
  1. app.py +196 -146
  2. concatenated_output.csv +1 -1
  3. utilities.py +97 -1
  4. utils_get_db_tables_info.py +12 -6
app.py CHANGED
@@ -12,6 +12,7 @@ import plotly.colors as pc
12
  from qatch.connectors.sqlite_connector import SqliteConnector
13
  from qatch.generate_dataset.orchestrator_generator import OrchestratorGenerator
14
  from qatch.evaluate_dataset.orchestrator_evaluator import OrchestratorEvaluator
 
15
  from prediction import ModelPrediction
16
  import utils_get_db_tables_info
17
  import utilities as us
@@ -31,7 +32,6 @@ import utilities as us
31
  #pnp_path = os.path.join("data", "evaluation_p_np_metrics.csv")
32
  pnp_path = "concatenated_output.csv"
33
  PATH_PKL_TABLES = 'tables_dict_beaver.pkl'
34
-
35
  js_func = """
36
  function refresh() {
37
  const url = new URL(window.location);
@@ -42,7 +42,8 @@ function refresh() {
42
  }
43
  }
44
  """
45
- reset_flag=False
 
46
 
47
  with open('style.css', 'r') as file:
48
  css = file.read()
@@ -65,6 +66,8 @@ description = """## 📊 Comparison of Proprietary and Non-Proprietary Databases
65
  ### ➀ **Non-Proprietary**
66
  ###     β‡’ Spider 1.0 πŸ•·οΈ"""
67
  prompt_default = "Translate the following question in SQL code to be executed over the database to fetch the answer.\nReturn the sql code in ```sql ```\nQuestion\n{question}\nDatabase Schema\n{db_schema}\n"
 
 
68
 
69
  input_data = {
70
  'input_method': "",
@@ -93,6 +96,7 @@ def load_data(file, path, use_default):
93
  #change path
94
  input_data["data_path"] = os.path.join(".", f"{input_data['db_name']}.sqlite")
95
  input_data["data"] = us.load_data(file, input_data["db_name"])
 
96
  df_current = input_data["data"]['data_frames'].get('MyTable', df_default) # Load the DataFrame
97
  if(input_data["data"]['data_frames'] and input_data["data"]["db"] is None): #for csv and xlsx files
98
  table2primary_key = {}
@@ -317,7 +321,6 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
317
 
318
  # Model selection button (initially disabled)
319
  open_model_selection = gr.Button("Choose your models", interactive=False)
320
-
321
  def update_table_list(data):
322
  """Dynamically updates the list of available tables and excluded ones."""
323
  if isinstance(data, dict) and data:
@@ -458,9 +461,7 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
458
  default_checkbox
459
  ]
460
  )
461
-
462
  reset_data.click(open_accordion, inputs=gr.State("reset"), outputs=[upload_acc, select_table_acc, select_model_acc, qatch_acc, metrics_acc, default_checkbox, file_input])
463
-
464
 
465
  ####################################
466
  # MODEL SELECTION PART #
@@ -506,10 +507,9 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
506
  # Function to get selected models
507
  def get_selected_models(*model_selections):
508
  selected_models = [model for model, selected in zip(model_list, model_selections) if selected]
509
-
510
  input_data['models'] = selected_models
511
  button_state = bool(selected_models and '{db_schema}' in input_data["prompt"] and '{question}' in input_data["prompt"])
512
- return selected_models, gr.update(open=True, visible=True), gr.update(interactive=button_state)
513
 
514
  # Add the Textbox to the interface
515
  prompt = gr.TextArea(
@@ -517,17 +517,19 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
517
  placeholder=prompt_default,
518
  elem_id="custom-textarea"
519
  )
 
520
  warning_prompt = gr.Markdown(value="## Error in the prompt format", visible=False)
521
 
522
  # Submit button (initially disabled)
523
-
524
- submit_models_button = gr.Button("Submit Models", interactive=False)
 
525
 
526
  def check_prompt(prompt):
527
  #TODO
528
  missing_elements = []
529
  if(prompt==""):
530
- input_data["prompt"]=prompt_default
531
  button_state = bool(len(input_data['models']) > 0 and '{db_schema}' in input_data["prompt"] and '{question}' in input_data["prompt"])
532
  else:
533
  input_data["prompt"]=prompt
@@ -544,18 +546,18 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
544
  ), gr.update(interactive=button_state)
545
  return gr.update(visible=False), gr.update(interactive=button_state)
546
 
547
- prompt.change(fn=check_prompt, inputs=[prompt], outputs=[warning_prompt, submit_models_button])
548
  # Link checkboxes to selection events
549
  for checkbox in model_checkboxes:
550
  checkbox.change(
551
  fn=get_selected_models,
552
  inputs=model_checkboxes,
553
- outputs=[selected_models_output, select_model_acc, submit_models_button]
554
  )
555
  prompt.change(
556
  fn=get_selected_models,
557
  inputs=model_checkboxes,
558
- outputs=[selected_models_output, select_model_acc, submit_models_button]
559
  )
560
 
561
  submit_models_button.click(
@@ -564,6 +566,17 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
564
  outputs=[selected_models_output, select_model_acc, qatch_acc]
565
  )
566
 
 
 
567
  def enable_disable(enable):
568
  return (
569
  *[gr.update(interactive=enable) for _ in model_checkboxes],
@@ -574,6 +587,7 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
574
  gr.update(interactive=enable),
575
  gr.update(interactive=enable),
576
  *[gr.update(interactive=enable) for _ in table_outputs],
 
577
  gr.update(interactive=enable)
578
  )
579
 
@@ -591,7 +605,24 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
591
  default_checkbox,
592
  table_selector,
593
  *table_outputs,
594
- open_model_selection
 
 
595
  ]
596
  )
597
 
@@ -609,7 +640,8 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
609
  default_checkbox,
610
  table_selector,
611
  *table_outputs,
612
- open_model_selection
 
613
  ]
614
  )
615
 
@@ -660,9 +692,10 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
660
  {mirrored_symbols}
661
  </div>
662
  """
663
- def qatch_flow():
664
- #caching
665
  global reset_flag
 
666
  predictions_dict = {model: pd.DataFrame(columns=['id', 'question', 'predicted_sql', 'time', 'query', 'db_path']) for model in model_list}
667
  metrics_conc = pd.DataFrame()
668
  columns_to_visulize = ["db_path", "tbl_name", "test_category", "sql_tag", "query", "question", "predicted_sql", "time", "price", "answer"]
@@ -692,7 +725,7 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
692
  </div>
693
  """
694
  yield gr.Markdown(), gr.Image(), gr.Markdown(load_text), gr.Markdown(display_question), gr.Markdown(), metrics_conc, *[predictions_dict[model][columns_to_visulize] for model in model_list]
695
- #time.sleep(0.02)
696
  prediction = row['predicted_sql']
697
 
698
  display_prediction = f"""<div class='loading' style='font-size: 1.7rem; font-family: 'Inter', sans-serif;'>Predicted SQL:</div>
@@ -700,22 +733,25 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
700
  <div style='font-size: 3rem'>➡️</div>
701
  <div class='sqlquery' font-family: 'Inter', sans-serif;>{prediction}</div>
702
  </div>
703
- """
 
704
  yield gr.Markdown(), gr.Image(), gr.Markdown(load_text), gr.Markdown(), gr.Markdown(display_prediction), metrics_conc, *[predictions_dict[model][columns_to_visulize] for model in model_list]
705
  yield gr.Markdown(), gr.Image(), gr.Markdown(load_text), gr.Markdown(), gr.Markdown(display_prediction), metrics_conc, *[predictions_dict[model][columns_to_visulize] for model in model_list]
706
  metrics_conc = target_df
707
- if 'valid_efficiency_score' not in metrics_conc.columns:
708
- metrics_conc['valid_efficiency_score'] = metrics_conc['VES']
709
  eval_text = generate_eval_text("End evaluation")
710
  yield gr.Markdown(eval_text, visible=True), gr.Image(), gr.Markdown(), gr.Markdown(), gr.Markdown(), metrics_conc, *[predictions_dict[model][columns_to_visulize] for model in model_list]
 
711
  else:
712
-
713
  orchestrator_generator = OrchestratorGenerator()
714
- # TODO: add to target_df column target_df["columns_used"], tables selection
715
- # print(input_data['data']['db'])
716
- #print(input_data['data']['selected_tables'])
717
- target_df = orchestrator_generator.generate_dataset(connector=input_data['data']['db'], tables_to_include=input_data['data']['selected_tables'])
718
- #target_df = orchestrator_generator.generate_dataset(connector=input_data['data']['db'], tables_to_includes=None)
 
 
719
 
720
  predictor = ModelPrediction()
721
  reset_flag = False
@@ -736,15 +772,18 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
736
  </div>
737
  """
738
  yield gr.Markdown(), gr.Image(), gr.Markdown(load_text), gr.Markdown(display_question), gr.Markdown(), metrics_conc, *[predictions_dict[model]for model in model_list]
739
- start_time = time.time()
740
- samples = us.generate_some_samples(input_data["data_path"], row["tbl_name"])
741
-
742
- schema_text = utils_get_db_tables_info.utils_extract_db_schema_as_string(
743
- db_id = input_data["db_name"],
744
- base_path = input_data["data_path"],
745
- normalize=False,
746
- sql=row["query"],
747
- get_insert_into=True
 
 
 
748
  )
749
 
750
  #prompt_to_send = us.prepare_prompt(input_data["prompt"], question, schema_text, samples)
@@ -752,19 +791,27 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
752
  #PREDICTION SQL
753
 
754
  # TODO add button for QA or SP and pass to .make_prediction parameter TASK
 
 
 
755
  response = predictor.make_prediction(
756
  question=question,
757
- db_schema=schema_text,
758
  model_name=model,
759
  prompt=f"{prompt_to_send}",
760
- task="SP" # TODO change accordingly
761
  )
762
  prediction = response['response_parsed']
763
  price = response['cost']
764
  answer = response['response']
765
 
766
  end_time = time.time()
767
- display_prediction = f"""<div class='loading' style='font-size: 1.7rem; font-family: 'Inter', sans-serif;'>Predicted SQL:</div>
 
 
768
  <div style='display: flex; align-items: center;'>
769
  <div style='font-size: 3rem'>➡️</div>
770
  <div class='sqlquery' font-family: 'Inter', sans-serif;>{prediction}</div>
@@ -779,40 +826,47 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
779
  'query': row["query"],
780
  'db_path': input_data["data_path"],
781
  'price':price,
782
- 'answer':answer,
783
  'number_question':count,
784
- 'prompt': prompt_to_send
 
785
  }]).dropna(how="all") # Remove only completely empty rows
786
  count=count+1
787
  # TODO: use a for loop
 
 
788
  for col in target_df.columns:
789
  if col not in new_row.columns:
790
  new_row[col] = row[col]
791
-
792
  # Update model's prediction dataframe incrementally
793
  if not new_row.empty:
794
  predictions_dict[model] = pd.concat([predictions_dict[model], new_row], ignore_index=True)
795
 
796
  # yield gr.Textbox(), gr.Textbox(prediction), *[predictions_dict[model] for model in input_data["models"]], None
797
  yield gr.Markdown(), gr.Image(), gr.Markdown(load_text), gr.Markdown(), gr.Markdown(display_prediction), metrics_conc, *[predictions_dict[model]for model in model_list]
798
-
799
  yield gr.Markdown(), gr.Image(), gr.Markdown(load_text), gr.Markdown(), gr.Markdown(display_prediction), metrics_conc, *[predictions_dict[model] for model in model_list]
800
  # END
801
  eval_text = generate_eval_text("Evaluation")
802
  yield gr.Markdown(eval_text, visible=True), gr.Image(), gr.Markdown(), gr.Markdown(), gr.Markdown(), metrics_conc, *[predictions_dict[model] for model in model_list]
 
803
  evaluator = OrchestratorEvaluator()
 
804
  for model in input_data["models"]:
805
- metrics_df_model = evaluator.evaluate_df(
806
- df=predictions_dict[model],
807
- target_col_name="query",
808
- prediction_col_name="predicted_sql",
809
- db_path_name="db_path"
810
- )
 
 
 
811
  metrics_df_model['model'] = model
812
  metrics_conc = pd.concat([metrics_conc, metrics_df_model], ignore_index=True)
813
-
814
- if 'valid_efficiency_score' not in metrics_conc.columns:
815
- metrics_conc['valid_efficiency_score'] = metrics_conc['VES']
 
816
  eval_text = generate_eval_text("End evaluation")
817
  yield gr.Markdown(eval_text, visible=True), gr.Image(), gr.Markdown(), gr.Markdown(), gr.Markdown(), metrics_conc, *[predictions_dict[model] for model in model_list]
818
 
@@ -848,6 +902,7 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
848
  gr.Markdown(f"**Results for {model}**")
849
  tab_dict[model] = tab
850
  dataframe_per_model[model] = gr.DataFrame()
 
851
  # download_pred_model = gr.DownloadButton(label="Download Prediction per Model", visible=False)
852
 
853
  evaluation_loading = gr.Markdown()
@@ -860,13 +915,24 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
860
  inputs=[],
861
  outputs=[tab_dict[model] for model in model_list] # Update TabItem visibility
862
  )
 
 
863
 
864
  selected_models_display = gr.JSON(label="Final input data", visible=False)
865
  metrics_df = gr.DataFrame(visible=False)
866
  metrics_df_out = gr.DataFrame(visible=False)
867
 
868
  submit_models_button.click(
869
- fn=qatch_flow,
 
 
870
  inputs=[],
871
  outputs=[evaluation_loading, model_logo, variable, question_display, prediction_display, metrics_df] + list(dataframe_per_model.values())
872
  )
@@ -875,6 +941,10 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
875
  fn=lambda: gr.update(value=input_data),
876
  outputs=[selected_models_display]
877
  )
 
 
878
 
879
  # Works for METRICS
880
  metrics_df.change(fn=change_text, inputs=[metrics_df], outputs=[metrics_df_out])
@@ -897,10 +967,16 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
897
  fn=lambda: gr.update(visible=False),
898
  outputs=[download_metrics]
899
  )
 
 
900
 
901
  def refresh():
902
  global reset_flag
 
903
  reset_flag = True
 
904
 
905
  reset_data = gr.Button("Back to upload data section", interactive=True)
906
 
@@ -926,10 +1002,12 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
926
  default_checkbox,
927
  table_selector,
928
  *table_outputs,
929
- open_model_selection
 
930
  ]
931
  )
932
-
 
933
  ##########################################
934
  # METRICS VISUALIZATION SECTION #
935
  ##########################################
@@ -944,8 +1022,9 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
944
  ####################################
945
 
946
  def load_data_csv_es():
947
-
948
  if input_data["input_method"]=="default":
 
949
  df = pd.read_csv(pnp_path)
950
  df = df[df['model'].isin(input_data["models"])]
951
  df = df[df['tbl_name'].isin(input_data["data"]["selected_tables"])]
@@ -956,6 +1035,7 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
956
  df['model'] = df['model'].replace('llama-70', 'Llama-70B')
957
  df['model'] = df['model'].replace('llama-8', 'Llama-8B')
958
  df['test_category'] = df['test_category'].replace('many-to-many-generator', 'MANY-TO-MANY')
 
959
  return df
960
  return metrics_df_out
961
 
@@ -998,20 +1078,21 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
998
 
999
  DB_CATEGORY_COLORS = generate_db_category_colors()
1000
 
1001
- def normalize_valid_efficiency_score(df):
1002
- #TODO valid_efficiency_score
1003
- #print(df['valid_efficiency_score'])
1004
- df['valid_efficiency_score'] = df['valid_efficiency_score'].replace([np.nan, ''], 0)
1005
- df['valid_efficiency_score'] = df['valid_efficiency_score'].astype(int)
1006
- min_val = df['valid_efficiency_score'].min()
1007
- max_val = df['valid_efficiency_score'].max()
1008
 
1009
- if min_val == max_val:
1010
- # All values are equal: assign 1.0 to everything to avoid division by zero
1011
- df['valid_efficiency_score'] = 1.0
 
 
 
1012
  else:
1013
- df['valid_efficiency_score'] = (
1014
- df['valid_efficiency_score'] - min_val
1015
  ) / (max_val - min_val)
1016
 
1017
  return df
@@ -1024,7 +1105,7 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
1024
  # BAR CHART FOR AVERAGE METRICS WITH UPDATE FUNCTION
1025
  def plot_metric(df, radio_metric, qatch_selected_metrics, external_selected_metric, group_by, selected_models):
1026
  df = df[df['model'].isin(selected_models)]
1027
- df = normalize_valid_efficiency_score(df)
1028
 
1029
  # Map human-readable names -> internal names
1030
  qatch_selected_internal = [qatch_metrics_dict[label] for label in qatch_selected_metrics]
@@ -1141,7 +1222,7 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
1141
  selected_models = [selected_models]
1142
 
1143
  df = df[df['model'].isin(selected_models)]
1144
- df = normalize_valid_efficiency_score(df)
1145
 
1146
  # Converti nomi leggibili -> tecnici
1147
  qatch_selected_internal = [qatch_metrics_dict[label] for label in qatch_selected_metrics]
@@ -1226,54 +1307,6 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
1226
  )
1227
 
1228
  return gr.Plot(fig, visible=True)
1229
-
1230
- """
1231
- def plot_metric_propietary(df, radio_metric, qatch_selected_metrics, external_selected_metric, selected_models):
1232
- if selected_models == "All":
1233
- selected_models = models
1234
- else:
1235
- selected_models = [selected_models]
1236
-
1237
- df = df[df['model'].isin(selected_models)]
1238
- df = normalize_valid_efficiency_score(df)
1239
-
1240
- if radio_metric == "Qatch":
1241
- selected_metrics = qatch_selected_metrics
1242
- else:
1243
- selected_metrics = external_selected_metric
1244
-
1245
- df = calculate_average_metrics(df, selected_metrics)
1246
-
1247
- # Group by model and category
1248
- avg_metrics = df.groupby(["model", "db_category"])['avg_metric'].mean().reset_index()
1249
- avg_metrics['text_label'] = avg_metrics['avg_metric'].apply(lambda x: f'{x:.2f}')
1250
-
1251
- # Horizontal bar plot with the model on the Y axis
1252
- fig = px.bar(
1253
- avg_metrics,
1254
- x='avg_metric',
1255
- y='model',
1256
- color='db_category', # category as color
1257
- text='text_label',
1258
- barmode='group',
1259
- orientation='h',
1260
- color_discrete_map=DB_CATEGORY_COLORS, # this dict must exist, analogous to MODEL_COLORS
1261
- title='Average metric per model and db_category 📊',
1262
- labels={'avg_metric': 'AVG Metric', 'model': 'Model'},
1263
- template='plotly_dark'
1264
- )
1265
-
1266
- fig.update_traces(textposition='outside', textfont_size=10)
1267
- fig.update_layout(
1268
- margin=dict(t=80),
1269
- yaxis=dict(title=''),
1270
- xaxis=dict(title='AVG Metrics'),
1271
- legend_title='DB Name',
1272
- height=600 # increase if there are many models
1273
- )
1274
-
1275
- return gr.Plot(fig, visible=True)
1276
- """
1277
 
1278
  def update_plot_propietary(radio_metric, qatch_selected_metrics, external_selected_metric, selected_models):
1279
  df = load_data_csv_es()
@@ -1289,7 +1322,7 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
1289
  df = df[df['db_category'].isin(target_cats)]
1290
  df = df[df['model'].isin(selected_models)]
1291
 
1292
- df = normalize_valid_efficiency_score(df)
1293
  df = calculate_average_metrics(df, qatch_metrics)
1294
 
1295
  # Compute the average per db_category and model
@@ -1410,14 +1443,14 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
1410
 
1411
  # RADAR OR BAR CHART BASED ON CATEGORY COUNT
1412
  def plot_radar(df, selected_models, selected_metrics, selected_categories):
1413
- if "external" in selected_metrics:
1414
- selected_metrics = ["execution_accuracy", "valid_efficiency_score"]
1415
  else:
1416
  selected_metrics = ["cell_precision", "cell_recall", "tuple_order", "tuple_cardinality", "tuple_constraint"]
1417
 
1418
  # Filter models and normalize
1419
  df = df[df['model'].isin(selected_models)]
1420
- df = normalize_valid_efficiency_score(df)
1421
  df = calculate_average_metrics(df, selected_metrics)
1422
 
1423
  avg_metrics = df.groupby(['model', 'test_category'])['avg_metric'].mean().reset_index()
@@ -1574,13 +1607,13 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
1574
 
1575
  # RADAR OR BAR CHART FOR SUB-CATEGORIES BASED ON CATEGORY COUNT
1576
  def plot_radar_sub(df, selected_models, selected_metrics, selected_category):
1577
- if "external" in selected_metrics:
1578
- selected_metrics = ["execution_accuracy", "valid_efficiency_score"]
1579
  else:
1580
  selected_metrics = ["cell_precision", "cell_recall", "tuple_order", "tuple_cardinality", "tuple_constraint"]
1581
 
1582
  df = df[df['model'].isin(selected_models)]
1583
- df = normalize_valid_efficiency_score(df)
1584
  df = calculate_average_metrics(df, selected_metrics)
1585
 
1586
  if isinstance(selected_category, str):
@@ -1743,6 +1776,7 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
1743
 
1744
  # RANKING FOR THE 3 WORST RESULTS WITH UPDATE FUNCTION
1745
  def worst_cases_text(df, selected_models, selected_metrics, selected_categories):
 
1746
  if selected_models == "All":
1747
  selected_models = models
1748
  else:
@@ -1757,15 +1791,25 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
1757
  df = df[df['test_category'].isin(selected_categories)]
1758
 
1759
  if "external" in selected_metrics:
1760
- selected_metrics = ["execution_accuracy", "valid_efficiency_score"]
1761
  else:
1762
  selected_metrics = ["cell_precision", "cell_recall", "tuple_order", "tuple_cardinality", "tuple_constraint"]
1763
 
1764
- df = normalize_valid_efficiency_score(df)
1765
  df = calculate_average_metrics(df, selected_metrics)
1766
-
1767
- worst_cases_df = df.groupby(['model', 'tbl_name', 'test_category', 'question', 'query', 'predicted_sql', 'answer', 'sql_tag'])['avg_metric'].mean().reset_index()
1768
 
 
 
1769
  worst_cases_df = worst_cases_df.sort_values(by="avg_metric", ascending=True).reset_index(drop=True)
1770
 
1771
  worst_cases_top_3 = worst_cases_df.head(3)
@@ -1778,14 +1822,24 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
1778
  medals = ["🥇", "🥈", "🥉"]
1779
 
1780
  for i, row in worst_cases_top_3.iterrows():
1781
- entry = (
1782
- f"<span style='font-size:18px;'><b>{medals[i]} {row['model']} - {row['tbl_name']} - {row['test_category']} - {row['sql_tag']}</b> ({row['avg_metric']})</span> \n"
1783
- f"<span style='font-size:16px;'>- <b>Question:</b> {row['question']}</span> \n"
1784
- f"<span style='font-size:16px;'>- <b>Original Query:</b> `{row['query']}`</span> \n"
1785
- f"<span style='font-size:16px;'>- <b>Predicted SQL:</b> `{row['predicted_sql']}`</span> \n\n"
1786
- )
 
 
1787
 
1788
- worst_str.append(entry)
1789
 
1790
  raw_answer = (
1791
  f"<span style='font-size:18px;'><b>{medals[i]} {row['model']} - {row['tbl_name']} - {row['test_category']} - {row['sql_tag']}</b> ({row['avg_metric']})</span> \n"
@@ -1793,7 +1847,7 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
1793
  )
1794
 
1795
  answer_str.append(raw_answer)
1796
-
1797
  return worst_str[0], worst_str[1], worst_str[2], answer_str[0], answer_str[1], answer_str[2]
1798
 
1799
  def update_worst_cases_text(selected_models, selected_metrics, selected_categories):
@@ -1803,7 +1857,7 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
1803
  # LINE CHART FOR CUMULATIVE TIME WITH UPDATE FUNCTION
1804
  def plot_cumulative_flow(df, selected_models, max_points):
1805
  df = df[df['model'].isin(selected_models)]
1806
- df = normalize_valid_efficiency_score(df)
1807
 
1808
  fig = go.Figure()
1809
 
@@ -1937,10 +1991,10 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
1937
 
1938
  external_metrics_dict = {
1939
  "Execution Accuracy": "execution_accuracy",
1940
- "Valid Efficiency Score": "valid_efficiency_score"
1941
  }
1942
 
1943
- external_metric = ["execution_accuracy", "valid_efficiency_score"]
1944
  last_valid_external_metric_selection = external_metric.copy()
1945
  def enforce_external_metric_selection(selected):
1946
  global last_valid_external_metric_selection
@@ -1987,10 +2041,6 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
1987
 
1988
  all_model_as_dic = {cat: [f"{cat}"] for cat in models}
1989
  all_model_as_dic["All"] = models
1990
-
1991
- #with gr.Blocks(theme=gr.themes.Default(primary_hue='blue')) as demo:
1992
-
1993
-
1994
 
1995
  ###########################
1996
  # VISUALIZATION SECTION #
@@ -2029,7 +2079,7 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
2029
  <span
2030
  title="External metric info:
2031
  Execution Accuracy: Checks if the predicted query returns exactly the same result as the ground truth query when executed. It is a binary metric: 1 if the output matches, 0 otherwise.
2032
- Valid Efficiency Score: Evaluates the efficiency of a query by combining execution time and correctness. It rewards queries that are both accurate and fast."
2033
  style="margin-left: 6px; cursor: help; color: #00bfff; font-size: 16px; white-space: pre-line;"
2034
  >External metric info ℹ️</span>
2035
  </div>
@@ -2304,6 +2354,6 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
2304
  reset_data.click(open_accordion, inputs=gr.State("reset"), outputs=[upload_acc, select_table_acc, select_model_acc, qatch_acc, metrics_acc, default_checkbox, file_input])
2305
  reset_data.click(fn=lambda: gr.update(visible=False), outputs=[download_metrics])
2306
  reset_data.click(fn=enable_disable, inputs=[gr.State(True)], outputs=[*model_checkboxes, submit_models_button, preview_output, submit_button, file_input, default_checkbox, table_selector, *table_outputs, open_model_selection])
2307
-
2308
 
2309
  interface.launch(share = True)
 
12
  from qatch.connectors.sqlite_connector import SqliteConnector
13
  from qatch.generate_dataset.orchestrator_generator import OrchestratorGenerator
14
  from qatch.evaluate_dataset.orchestrator_evaluator import OrchestratorEvaluator
15
+ import qatch.evaluate_dataset.orchestrator_evaluator as eva
16
  from prediction import ModelPrediction
17
  import utils_get_db_tables_info
18
  import utilities as us
 
32
  #pnp_path = os.path.join("data", "evaluation_p_np_metrics.csv")
33
  pnp_path = "concatenated_output.csv"
34
  PATH_PKL_TABLES = 'tables_dict_beaver.pkl'
 
35
  js_func = """
36
  function refresh() {
37
  const url = new URL(window.location);
 
42
  }
43
  }
44
  """
45
+ reset_flag = False
46
+ flag_TQA = False
47
 
48
  with open('style.css', 'r') as file:
49
  css = file.read()
 
66
  ### ➀ **Non-Proprietary**
67
  ### &ensp;&ensp;&ensp; ⇒ Spider 1.0 🕷️"""
68
  prompt_default = "Translate the following question in SQL code to be executed over the database to fetch the answer.\nReturn the sql code in ```sql ```\nQuestion\n{question}\nDatabase Schema\n{db_schema}\n"
69
+ prompt_default_tqa = "Return the answer to the following question based on the provided database. Return your answer as the result of a query executed over the database, namely as a list of lists where the outer list represents the tuples and each inner list the values in that tuple.\n Return the answer in an answer tag as <answer> </answer>.\n Question \n {question}\n Database Schema\n {db_schema}\n"
70
+
71
 
72
  input_data = {
73
  'input_method': "",
 
96
  #change path
97
  input_data["data_path"] = os.path.join(".", f"{input_data['db_name']}.sqlite")
98
  input_data["data"] = us.load_data(file, input_data["db_name"])
99
+
100
  df_current = input_data["data"]['data_frames'].get('MyTable', df_default) # Load the DataFrame
101
  if(input_data["data"]['data_frames'] and input_data["data"]["db"] is None): #for csv and xlsx files
102
  table2primary_key = {}
 
321
 
322
  # Model selection button (initially disabled)
323
  open_model_selection = gr.Button("Choose your models", interactive=False)
 
324
  def update_table_list(data):
325
  """Dynamically updates the list of available tables and excluded ones."""
326
  if isinstance(data, dict) and data:
 
461
  default_checkbox
462
  ]
463
  )
 
464
  reset_data.click(open_accordion, inputs=gr.State("reset"), outputs=[upload_acc, select_table_acc, select_model_acc, qatch_acc, metrics_acc, default_checkbox, file_input])
 
465
 
466
  ####################################
467
  # MODEL SELECTION PART #
 
507
  # Function to get selected models
508
  def get_selected_models(*model_selections):
509
  selected_models = [model for model, selected in zip(model_list, model_selections) if selected]
 
510
  input_data['models'] = selected_models
511
  button_state = bool(selected_models and '{db_schema}' in input_data["prompt"] and '{question}' in input_data["prompt"])
512
+ return selected_models, gr.update(open=True, visible=True), gr.update(interactive=button_state), gr.update(interactive=button_state)
513
 
514
  # Add the Textbox to the interface
515
  prompt = gr.TextArea(
 
517
  placeholder=prompt_default,
518
  elem_id="custom-textarea"
519
  )
520
+
521
  warning_prompt = gr.Markdown(value="## Error in the prompt format", visible=False)
522
 
523
  # Submit button (initially disabled)
524
+ with gr.Row():
525
+ submit_models_button = gr.Button("Submit Models for NL2SQL task", interactive=False)
526
+ submit_models_button_tqa = gr.Button("Submit Models for TQA task", interactive=False)
527
 
528
  def check_prompt(prompt):
529
  #TODO
530
  missing_elements = []
531
  if(prompt==""):
532
+ input_data["prompt"] = prompt_default
533
  button_state = bool(len(input_data['models']) > 0 and '{db_schema}' in input_data["prompt"] and '{question}' in input_data["prompt"])
534
  else:
535
  input_data["prompt"]=prompt
 
546
  ), gr.update(interactive=button_state)
547
  return gr.update(visible=False), gr.update(interactive=button_state)
548
 
549
+ prompt.change(fn=check_prompt, inputs=[prompt], outputs=[warning_prompt, submit_models_button, submit_models_button_tqa])
550
  # Link checkboxes to selection events
551
  for checkbox in model_checkboxes:
552
  checkbox.change(
553
  fn=get_selected_models,
554
  inputs=model_checkboxes,
555
+ outputs=[selected_models_output, select_model_acc, submit_models_button, submit_models_button_tqa]
556
  )
557
  prompt.change(
558
  fn=get_selected_models,
559
  inputs=model_checkboxes,
560
+ outputs=[selected_models_output, select_model_acc, submit_models_button, submit_models_button_tqa]
561
  )
562
 
563
  submit_models_button.click(
 
566
  outputs=[selected_models_output, select_model_acc, qatch_acc]
567
  )
568
 
569
+ submit_models_button_tqa.click(
570
+ fn=lambda *args: (get_selected_models(*args), gr.update(open=False, visible=True), gr.update(open=True, visible=True)),
571
+ inputs=model_checkboxes,
572
+ outputs=[selected_models_output, select_model_acc, qatch_acc]
573
+ )
574
+ def change_flag():
575
+ global flag_TQA
576
+ flag_TQA = True
577
+
578
+ submit_models_button_tqa.click(fn = change_flag, inputs=[], outputs=[])
579
+
580
  def enable_disable(enable):
581
  return (
582
  *[gr.update(interactive=enable) for _ in model_checkboxes],
 
587
  gr.update(interactive=enable),
588
  gr.update(interactive=enable),
589
  *[gr.update(interactive=enable) for _ in table_outputs],
590
+ gr.update(interactive=enable),
591
  gr.update(interactive=enable)
592
  )
593
 
 
605
  default_checkbox,
606
  table_selector,
607
  *table_outputs,
608
+ open_model_selection,
609
+ submit_models_button_tqa
610
+ ]
611
+ )
612
+ submit_models_button_tqa.click(
613
+ fn=enable_disable,
614
+ inputs=[gr.State(False)],
615
+ outputs=[
616
+ *model_checkboxes,
617
+ submit_models_button,
618
+ preview_output,
619
+ submit_button,
620
+ file_input,
621
+ default_checkbox,
622
+ table_selector,
623
+ *table_outputs,
624
+ open_model_selection,
625
+ submit_models_button_tqa
626
  ]
627
  )
628
 
 
640
  default_checkbox,
641
  table_selector,
642
  *table_outputs,
643
+ open_model_selection,
644
+ submit_models_button_tqa
645
  ]
646
  )
647
 
 
692
  {mirrored_symbols}
693
  </div>
694
  """
695
+
696
+ def qatch_flow_nl_sql():
697
  global reset_flag
698
+ global flag_TQA
699
  predictions_dict = {model: pd.DataFrame(columns=['id', 'question', 'predicted_sql', 'time', 'query', 'db_path']) for model in model_list}
700
  metrics_conc = pd.DataFrame()
701
  columns_to_visulize = ["db_path", "tbl_name", "test_category", "sql_tag", "query", "question", "predicted_sql", "time", "price", "answer"]
 
725
  </div>
726
  """
727
  yield gr.Markdown(), gr.Image(), gr.Markdown(load_text), gr.Markdown(display_question), gr.Markdown(), metrics_conc, *[predictions_dict[model][columns_to_visulize] for model in model_list]
728
+
729
  prediction = row['predicted_sql']
730
 
731
  display_prediction = f"""<div class='loading' style='font-size: 1.7rem; font-family: 'Inter', sans-serif;'>Predicted SQL:</div>
 
733
  <div style='font-size: 3rem'>➡️</div>
734
  <div class='sqlquery' font-family: 'Inter', sans-serif;>{prediction}</div>
735
  </div>
736
+ """
737
+
738
  yield gr.Markdown(), gr.Image(), gr.Markdown(load_text), gr.Markdown(), gr.Markdown(display_prediction), metrics_conc, *[predictions_dict[model][columns_to_visulize] for model in model_list]
739
  yield gr.Markdown(), gr.Image(), gr.Markdown(load_text), gr.Markdown(), gr.Markdown(display_prediction), metrics_conc, *[predictions_dict[model][columns_to_visulize] for model in model_list]
740
  metrics_conc = target_df
741
+ if 'valid_efficency_score' not in metrics_conc.columns:
742
+ metrics_conc['valid_efficency_score'] = metrics_conc['VES']
743
  eval_text = generate_eval_text("End evaluation")
744
  yield gr.Markdown(eval_text, visible=True), gr.Image(), gr.Markdown(), gr.Markdown(), gr.Markdown(), metrics_conc, *[predictions_dict[model][columns_to_visulize] for model in model_list]
745
+
746
  else:
 
747
  orchestrator_generator = OrchestratorGenerator()
748
+ target_df = orchestrator_generator.generate_dataset(connector=input_data['data']['db'], tables_to_includes=input_data['data']['selected_tables'])
749
+
750
+ #create target_df[target_answer]
751
+ if flag_TQA :
752
+ if (input_data["prompt"] == prompt_default):
753
+ input_data["prompt"] = prompt_default_tqa
754
+ target_df = us.extract_answer(target_df)
755
 
756
  predictor = ModelPrediction()
757
  reset_flag = False
 
772
  </div>
773
  """
774
  yield gr.Markdown(), gr.Image(), gr.Markdown(load_text), gr.Markdown(display_question), gr.Markdown(), metrics_conc, *[predictions_dict[model]for model in model_list]
775
+ #samples = us.generate_some_samples(input_data["data_path"], row["tbl_name"])
776
+ model_to_send = None if not flag_TQA else model
777
+
778
+
779
+ db_schema_text = utils_get_db_tables_info.utils_extract_db_schema_as_string(
780
+ db_id = input_data["db_name"],
781
+ base_path = input_data["data_path"],
782
+ normalize=False,
783
+ sql=row["query"],
784
+ get_insert_into=True,
785
+ model = model_to_send,
786
+ prompt = input_data["prompt"].format(question=question, db_schema=""),
787
  )
788
 
789
  #prompt_to_send = us.prepare_prompt(input_data["prompt"], question, schema_text, samples)
 
791
  #PREDICTION SQL
792
 
793
  # TODO add button for QA or SP and pass to .make_prediction parameter TASK
794
+ if flag_TQA: task="QA"
795
+ else: task="SP"
796
+ start_time = time.time()
797
  response = predictor.make_prediction(
798
  question=question,
799
+ db_schema=db_schema_text,
800
  model_name=model,
801
  prompt=f"{prompt_to_send}",
802
+ task=task
803
  )
804
  prediction = response['response_parsed']
805
  price = response['cost']
806
  answer = response['response']
807
 
808
  end_time = time.time()
809
+ if flag_TQA:
810
+ task_string = "Answer"
811
+ else:
812
+ task_string = "SQL"
813
+
814
+ display_prediction = f"""<div class='loading' style='font-size: 1.7rem; font-family: 'Inter', sans-serif;'>Predicted {task_string}:</div>
815
  <div style='display: flex; align-items: center;'>
816
  <div style='font-size: 3rem'>➡️</div>
817
  <div class='sqlquery' font-family: 'Inter', sans-serif;>{prediction}</div>
 
826
  'query': row["query"],
827
  'db_path': input_data["data_path"],
828
  'price':price,
829
+ 'answer': answer,
830
  'number_question':count,
831
+ 'target_answer' : row["target_answer"] if flag_TQA else None,
832
+
833
  }]).dropna(how="all") # Remove only completely empty rows
834
  count=count+1
835
  # TODO: use a for loop
836
+ if (flag_TQA) :
837
+ new_row['predicted_answer'] = prediction
838
  for col in target_df.columns:
839
  if col not in new_row.columns:
840
  new_row[col] = row[col]
 
841
  # Update model's prediction dataframe incrementally
842
  if not new_row.empty:
843
  predictions_dict[model] = pd.concat([predictions_dict[model], new_row], ignore_index=True)
844
 
845
  # yield gr.Textbox(), gr.Textbox(prediction), *[predictions_dict[model] for model in input_data["models"]], None
846
  yield gr.Markdown(), gr.Image(), gr.Markdown(load_text), gr.Markdown(), gr.Markdown(display_prediction), metrics_conc, *[predictions_dict[model]for model in model_list]
 
847
  yield gr.Markdown(), gr.Image(), gr.Markdown(load_text), gr.Markdown(), gr.Markdown(display_prediction), metrics_conc, *[predictions_dict[model] for model in model_list]
848
  # END
849
  eval_text = generate_eval_text("Evaluation")
850
  yield gr.Markdown(eval_text, visible=True), gr.Image(), gr.Markdown(), gr.Markdown(), gr.Markdown(), metrics_conc, *[predictions_dict[model] for model in model_list]
851
+
852
  evaluator = OrchestratorEvaluator()
853
+
854
  for model in input_data["models"]:
855
+ if not flag_TQA:
856
+ metrics_df_model = evaluator.evaluate_df(
857
+ df=predictions_dict[model],
858
+ target_col_name="query",
859
+ prediction_col_name="predicted_sql",
860
+ db_path_name="db_path"
861
+ )
862
+ else:
863
+ metrics_df_model = us.evaluate_answer(predictions_dict[model])
864
  metrics_df_model['model'] = model
865
  metrics_conc = pd.concat([metrics_conc, metrics_df_model], ignore_index=True)
866
+
867
+ if 'valid_efficency_score' not in metrics_conc.columns and 'VES' in metrics_conc.columns:
868
+ metrics_conc['valid_efficency_score'] = metrics_conc['VES']
869
+
870
  eval_text = generate_eval_text("End evaluation")
871
  yield gr.Markdown(eval_text, visible=True), gr.Image(), gr.Markdown(), gr.Markdown(), gr.Markdown(), metrics_conc, *[predictions_dict[model] for model in model_list]
872
 
 
902
  gr.Markdown(f"**Results for {model}**")
903
  tab_dict[model] = tab
904
  dataframe_per_model[model] = gr.DataFrame()
905
+ #TODO download metrics per model
906
  # download_pred_model = gr.DownloadButton(label="Download Prediction per Model", visible=False)
907
 
908
  evaluation_loading = gr.Markdown()
 
915
  inputs=[],
916
  outputs=[tab_dict[model] for model in model_list] # Update TabItem visibility
917
  )
918
+ submit_models_button_tqa.click(
919
+ change_tab,
920
+ inputs=[],
921
+ outputs=[tab_dict[model] for model in model_list] # Update TabItem visibility
922
+ )
923
 
924
  selected_models_display = gr.JSON(label="Final input data", visible=False)
925
  metrics_df = gr.DataFrame(visible=False)
926
  metrics_df_out = gr.DataFrame(visible=False)
927
 
928
  submit_models_button.click(
929
+ fn=qatch_flow_nl_sql,
930
+ inputs=[],
931
+ outputs=[evaluation_loading, model_logo, variable, question_display, prediction_display, metrics_df] + list(dataframe_per_model.values())
932
+ )
933
+
934
+ submit_models_button_tqa.click(
935
+ fn=qatch_flow_nl_sql,
936
  inputs=[],
937
  outputs=[evaluation_loading, model_logo, variable, question_display, prediction_display, metrics_df] + list(dataframe_per_model.values())
938
  )
 
941
  fn=lambda: gr.update(value=input_data),
942
  outputs=[selected_models_display]
943
  )
944
+ submit_models_button_tqa.click(
945
+ fn=lambda: gr.update(value=input_data),
946
+ outputs=[selected_models_display]
947
+ )
948
 
949
  # Works for METRICS
950
  metrics_df.change(fn=change_text, inputs=[metrics_df], outputs=[metrics_df_out])
 
967
  fn=lambda: gr.update(visible=False),
968
  outputs=[download_metrics]
969
  )
970
+ submit_models_button_tqa.click(
971
+ fn=lambda: gr.update(visible=False),
972
+ outputs=[download_metrics]
973
+ )
974
 
975
  def refresh():
976
  global reset_flag
977
+ global flag_TQA
978
  reset_flag = True
979
+ flag_TQA = False
980
 
981
  reset_data = gr.Button("Back to upload data section", interactive=True)
982
 
 
1002
  default_checkbox,
1003
  table_selector,
1004
  *table_outputs,
1005
+ open_model_selection,
1006
+ submit_models_button_tqa
1007
  ]
1008
  )
1009
+
1010
+
1011
  ##########################################
1012
  # METRICS VISUALIZATION SECTION #
1013
  ##########################################
 
1022
  ####################################
1023
 
1024
  def load_data_csv_es():
1025
+
1026
  if input_data["input_method"]=="default":
1027
+ global flag_TQA
1028
  df = pd.read_csv(pnp_path)
1029
  df = df[df['model'].isin(input_data["models"])]
1030
  df = df[df['tbl_name'].isin(input_data["data"]["selected_tables"])]
 
1035
  df['model'] = df['model'].replace('llama-70', 'Llama-70B')
1036
  df['model'] = df['model'].replace('llama-8', 'Llama-8B')
1037
  df['test_category'] = df['test_category'].replace('many-to-many-generator', 'MANY-TO-MANY')
1038
+ if (flag_TQA) : flag_TQA = False #TODO delete after make pred
1039
  return df
1040
  return metrics_df_out
1041
 
 
1078
 
1079
  DB_CATEGORY_COLORS = generate_db_category_colors()
1080
 
1081
+ def normalize_valid_efficency_score(df):
1082
+ df['valid_efficency_score'] = df['valid_efficency_score'].replace([np.nan, ''], 0)
1083
+ df['valid_efficency_score'] = df['valid_efficency_score'].astype(int)
1084
+ min_val = df['valid_efficency_score'].min()
1085
+ max_val = df['valid_efficency_score'].max()
 
 
1086
 
1087
+ if min_val == max_val :
1088
+ # All values are equal: to avoid division by zero, set the score to 1.0 (or 0 if the minimum is missing)
1089
+ if min_val is None:
1090
+ df['valid_efficency_score'] = 0
1091
+ else:
1092
+ df['valid_efficency_score'] = 1.0
1093
  else:
1094
+ df['valid_efficency_score'] = (
1095
+ df['valid_efficency_score'] - min_val
1096
  ) / (max_val - min_val)
1097
 
1098
  return df
 
1105
  # BAR CHART FOR AVERAGE METRICS WITH UPDATE FUNCTION
1106
  def plot_metric(df, radio_metric, qatch_selected_metrics, external_selected_metric, group_by, selected_models):
1107
  df = df[df['model'].isin(selected_models)]
1108
+ df = normalize_valid_efficency_score(df)
1109
 
1110
  # Map human-readable names -> internal names
1111
  qatch_selected_internal = [qatch_metrics_dict[label] for label in qatch_selected_metrics]
 
1222
  selected_models = [selected_models]
1223
 
1224
  df = df[df['model'].isin(selected_models)]
1225
+ df = normalize_valid_efficency_score(df)
1226
 
1227
  # Converti nomi leggibili -> tecnici
1228
  qatch_selected_internal = [qatch_metrics_dict[label] for label in qatch_selected_metrics]
 
1307
  )
1308
 
1309
  return gr.Plot(fig, visible=True)
 
 
1310
 
1311
  def update_plot_propietary(radio_metric, qatch_selected_metrics, external_selected_metric, selected_models):
1312
  df = load_data_csv_es()
 
1322
  df = df[df['db_category'].isin(target_cats)]
1323
  df = df[df['model'].isin(selected_models)]
1324
 
1325
+ df = normalize_valid_efficency_score(df)
1326
  df = calculate_average_metrics(df, qatch_metrics)
1327
 
1328
  # Compute the average per db_category and model
 
1443
 
1444
  # RADAR OR BAR CHART BASED ON CATEGORY COUNT
1445
  def plot_radar(df, selected_models, selected_metrics, selected_categories):
1446
+ if "External" in selected_metrics:
1447
+ selected_metrics = ["execution_accuracy", "valid_efficency_score"]
1448
  else:
1449
  selected_metrics = ["cell_precision", "cell_recall", "tuple_order", "tuple_cardinality", "tuple_constraint"]
1450
 
1451
  # Filter models and normalize
1452
  df = df[df['model'].isin(selected_models)]
1453
+ df = normalize_valid_efficency_score(df)
1454
  df = calculate_average_metrics(df, selected_metrics)
1455
 
1456
  avg_metrics = df.groupby(['model', 'test_category'])['avg_metric'].mean().reset_index()
 
1607
 
1608
  # RADAR OR BAR CHART FOR SUB-CATEGORIES BASED ON CATEGORY COUNT
1609
  def plot_radar_sub(df, selected_models, selected_metrics, selected_category):
1610
+ if "External" in selected_metrics:
1611
+ selected_metrics = ["execution_accuracy", "valid_efficency_score"]
1612
  else:
1613
  selected_metrics = ["cell_precision", "cell_recall", "tuple_order", "tuple_cardinality", "tuple_constraint"]
1614
 
1615
  df = df[df['model'].isin(selected_models)]
1616
+ df = normalize_valid_efficency_score(df)
1617
  df = calculate_average_metrics(df, selected_metrics)
1618
 
1619
  if isinstance(selected_category, str):
 
1776
 
1777
  # RANKING FOR THE 3 WORST RESULTS WITH UPDATE FUNCTION
1778
  def worst_cases_text(df, selected_models, selected_metrics, selected_categories):
1779
+ global flag_TQA
1780
  if selected_models == "All":
1781
  selected_models = models
1782
  else:
 
1791
  df = df[df['test_category'].isin(selected_categories)]
1792
 
1793
  if "external" in selected_metrics:
1794
+ selected_metrics = ["execution_accuracy", "valid_efficency_score"]
1795
  else:
1796
  selected_metrics = ["cell_precision", "cell_recall", "tuple_order", "tuple_cardinality", "tuple_constraint"]
1797
 
1798
+ df = normalize_valid_efficency_score(df)
1799
  df = calculate_average_metrics(df, selected_metrics)
 
 
1800
 
1801
+ if flag_TQA:
1802
+ df["target_answer"] = df["target_answer"].apply(
1803
+ lambda x: " - ".join([",".join(map(str, item)) for item in x]) if isinstance(x, list) else str(x)
1804
+ )
1805
+ df["predicted_answer"] = df["predicted_answer"].apply(
1806
+ lambda x: " - ".join([",".join(map(str, item)) for item in x]) if isinstance(x, list) else str(x)
1807
+ )
1808
+
1809
+ worst_cases_df = df.groupby(['model', 'tbl_name', 'test_category', 'question', 'target_answer', 'predicted_answer', 'answer', 'sql_tag'])['avg_metric'].mean().reset_index()
1810
+ else:
1811
+ worst_cases_df = df.groupby(['model', 'tbl_name', 'test_category', 'question', 'query', 'predicted_sql', 'answer', 'sql_tag'])['avg_metric'].mean().reset_index()
1812
+
1813
  worst_cases_df = worst_cases_df.sort_values(by="avg_metric", ascending=True).reset_index(drop=True)
1814
 
1815
  worst_cases_top_3 = worst_cases_df.head(3)
 
1822
  medals = ["🥇", "🥈", "🥉"]
1823
 
1824
  for i, row in worst_cases_top_3.iterrows():
1825
+ if flag_TQA:
1826
+ entry = (
1827
+ f"<span style='font-size:18px;'><b>{medals[i]} {row['model']} - {row['tbl_name']} - {row['test_category']} - {row['sql_tag']}</b> ({row['avg_metric']})</span> \n"
1828
+ f"<span style='font-size:16px;'>- <b>Question:</b> {row['question']}</span> \n"
1829
+ f"<span style='font-size:16px;'>- <b>Original Answer:</b> `{row['target_answer']}`</span> \n"
1830
+ f"<span style='font-size:16px;'>- <b>Predicted Answer:</b> `{row['predicted_answer']}`</span> \n\n"
1831
+ )
1832
+
1833
+ worst_str.append(entry)
1834
+ else:
1835
+ entry = (
1836
+ f"<span style='font-size:18px;'><b>{medals[i]} {row['model']} - {row['tbl_name']} - {row['test_category']} - {row['sql_tag']}</b> ({row['avg_metric']})</span> \n"
1837
+ f"<span style='font-size:16px;'>- <b>Question:</b> {row['question']}</span> \n"
1838
+ f"<span style='font-size:16px;'>- <b>Original Query:</b> `{row['query']}`</span> \n"
1839
+ f"<span style='font-size:16px;'>- <b>Predicted SQL:</b> `{row['predicted_sql']}`</span> \n\n"
1840
+ )
1841
 
1842
+ worst_str.append(entry)
1843
 
1844
  raw_answer = (
1845
  f"<span style='font-size:18px;'><b>{medals[i]} {row['model']} - {row['tbl_name']} - {row['test_category']} - {row['sql_tag']}</b> ({row['avg_metric']})</span> \n"
 
1847
  )
1848
 
1849
  answer_str.append(raw_answer)
1850
+
1851
  return worst_str[0], worst_str[1], worst_str[2], answer_str[0], answer_str[1], answer_str[2]
1852
 
1853
  def update_worst_cases_text(selected_models, selected_metrics, selected_categories):
 
1857
  # LINE CHART FOR CUMULATIVE TIME WITH UPDATE FUNCTION
1858
  def plot_cumulative_flow(df, selected_models, max_points):
1859
  df = df[df['model'].isin(selected_models)]
1860
+ df = normalize_valid_efficency_score(df)
1861
 
1862
  fig = go.Figure()
1863
 
 
1991
 
1992
  external_metrics_dict = {
1993
  "Execution Accuracy": "execution_accuracy",
1994
+ "Valid Efficency Score": "valid_efficency_score"
1995
  }
1996
 
1997
+ external_metric = ["execution_accuracy", "valid_efficency_score"]
1998
  last_valid_external_metric_selection = external_metric.copy()
1999
  def enforce_external_metric_selection(selected):
2000
  global last_valid_external_metric_selection
 
2041
 
2042
  all_model_as_dic = {cat: [f"{cat}"] for cat in models}
2043
  all_model_as_dic["All"] = models
 
 
2044
 
2045
  ###########################
2046
  # VISUALIZATION SECTION #
 
2079
  <span
2080
  title="External metric info:
2081
  Execution Accuracy: Checks if the predicted query returns exactly the same result as the ground truth query when executed. It is a binary metric: 1 if the output matches, 0 otherwise.
2082
+ Valid Efficency Score: Evaluates the efficency of a query by combining execution time and correctness. It rewards queries that are both accurate and fast."
2083
  style="margin-left: 6px; cursor: help; color: #00bfff; font-size: 16px; white-space: pre-line;"
2084
  >External metric info ℹ️</span>
2085
  </div>
 
2354
  reset_data.click(open_accordion, inputs=gr.State("reset"), outputs=[upload_acc, select_table_acc, select_model_acc, qatch_acc, metrics_acc, default_checkbox, file_input])
2355
  reset_data.click(fn=lambda: gr.update(visible=False), outputs=[download_metrics])
2356
  reset_data.click(fn=enable_disable, inputs=[gr.State(True)], outputs=[*model_checkboxes, submit_models_button, preview_output, submit_button, file_input, default_checkbox, table_selector, *table_outputs, open_model_selection])
2357
+
2358
 
2359
  interface.launch(share = True)
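The behavioural core of the app.py changes is the flag_TQA switch: the new TQA button reuses the same generation flow but swaps the prompt template and the task string passed to ModelPrediction.make_prediction ("QA" instead of "SP"). A minimal sketch of that branching, under the assumption that the two templates are passed in explicitly (resolve_task_and_prompt is an illustrative helper, not part of the repository):

```python
def resolve_task_and_prompt(flag_tqa: bool, current_prompt: str,
                            default_sql_prompt: str, default_tqa_prompt: str) -> tuple[str, str]:
    """Sketch of the branching in qatch_flow_nl_sql: pick the task and the prompt template."""
    if flag_tqa:
        # Table Question Answering: the model must return the answer tuples directly.
        task = "QA"
        # If the user kept the default NL2SQL prompt, swap in the TQA template.
        prompt = default_tqa_prompt if current_prompt == default_sql_prompt else current_prompt
    else:
        # NL2SQL (semantic parsing): the model must return a SQL query.
        task = "SP"
        prompt = current_prompt
    return task, prompt

# Example: the TQA button sets flag_TQA, so the QA task and the TQA template are selected.
# task, prompt = resolve_task_and_prompt(True, prompt_default, prompt_default, prompt_default_tqa)
```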
concatenated_output.csv CHANGED
@@ -1,4 +1,4 @@
1
- cell_precision,sql_tag,tuple_cardinality,answer,predicted_sql,db_category,tuple_constraint,VES,number_question,valid_efficiency_score,tbl_name,tuple_order,time,price,question,model,cell_recall,db_path,execution_accuracy,test_category,query
2
  1.0,DISTINCT-SINGLE,1.0,"```sql
3
  SELECT DISTINCT WAREHOUSE_LOAD_DATE
4
  FROM FAC_BUILDING_ADDRESS;
 
1
+ cell_precision,sql_tag,tuple_cardinality,answer,predicted_sql,db_category,tuple_constraint,VES,number_question,valid_efficency_score,tbl_name,tuple_order,time,price,question,model,cell_recall,db_path,execution_accuracy,test_category,query
2
  1.0,DISTINCT-SINGLE,1.0,"```sql
3
  SELECT DISTINCT WAREHOUSE_LOAD_DATE
4
  FROM FAC_BUILDING_ADDRESS;
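The only change to concatenated_output.csv is the header rename from valid_efficiency_score to valid_efficency_score, matching the identifier now used in app.py and utilities.py. A hedged compatibility sketch for loading older exports (load_pnp_metrics is an illustrative helper; the VES fallback mirrors the one applied to metrics_conc in app.py):

```python
import pandas as pd

def load_pnp_metrics(path: str = "concatenated_output.csv") -> pd.DataFrame:
    """Load the precomputed metrics CSV, tolerating both spellings of the VES column."""
    df = pd.read_csv(path)
    if "valid_efficency_score" not in df.columns:
        if "valid_efficiency_score" in df.columns:
            # Older export with the previous column name.
            df = df.rename(columns={"valid_efficiency_score": "valid_efficency_score"})
        elif "VES" in df.columns:
            # Same fallback app.py applies to metrics_conc.
            df["valid_efficency_score"] = df["VES"]
    return df
```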
utilities.py CHANGED
@@ -6,6 +6,11 @@ import sqlite3
6
  import gradio as gr
7
  import os
8
  from qatch.connectors.sqlite_connector import SqliteConnector
 
 
9
  def extract_tables(file_path):
10
  conn = sqlite3.connect(file_path)
11
  cursor = conn.cursor()
@@ -26,7 +31,7 @@ def extract_dataframes(file_path):
26
  return dfs
27
 
28
  def carica_sqlite(file_path, db_id):
29
- data_output = {'data_frames': extract_dataframes(file_path),'db':SqliteConnector(relative_db_path=file_path, db_name=db_id)}
30
  return data_output
31
 
32
  # Function to read a CSV file
@@ -113,3 +118,94 @@ def generate_some_samples(file_path, tbl_name):
113
  def load_tables_dict_from_pkl(file_path):
114
  with open(file_path, 'rb') as f:
115
  return pickle.load(f)
 
 
 
6
  import gradio as gr
7
  import os
8
  from qatch.connectors.sqlite_connector import SqliteConnector
9
+ from qatch.evaluate_dataset.metrics_evaluators import CellPrecision, CellRecall, ExecutionAccuracy, TupleCardinality, TupleConstraint, TupleOrder, ValidEfficiencyScore
10
+ import qatch.evaluate_dataset.orchestrator_evaluator as eva
11
+ #import tiktoken
12
+ from transformers import AutoTokenizer
13
+
14
  def extract_tables(file_path):
15
  conn = sqlite3.connect(file_path)
16
  cursor = conn.cursor()
 
31
  return dfs
32
 
33
  def carica_sqlite(file_path, db_id):
34
+ data_output = {'data_frames': extract_dataframes(file_path),'db': SqliteConnector(relative_db_path=file_path, db_name=db_id)}
35
  return data_output
36
 
37
  # Function to read a CSV file
 
118
  def load_tables_dict_from_pkl(file_path):
119
  with open(file_path, 'rb') as f:
120
  return pickle.load(f)
121
+
122
+ def extract_tables_dict(pnp_path):
123
+ return load_tables_dict_from_pkl('tables_dict_beaver.pkl')
124
+ tables_dict = {}
125
+ with open(pnp_path, mode='r', encoding='utf-8') as file:
126
+ reader = csv.DictReader(file)
127
+ tbl_db_pairs = set() # Use a set to avoid duplicates
128
+ for row in reader:
129
+ tbl_name = row.get("tbl_name")
130
+ db_path = row.get("db_path")
131
+ if tbl_name and db_path:
132
+ tbl_db_pairs.add((tbl_name, db_path)) # Add the pair to the set
133
+ for tbl_name, db_path in list(tbl_db_pairs):
134
+ if tbl_name and db_path:
135
+ connector = sqlite3.connect(db_path)
136
+ query = f"SELECT * FROM {tbl_name} LIMIT 5"
137
+ try:
138
+ df = pd.read_sql_query(query, connector)
139
+ tables_dict[tbl_name] = df
140
+ except Exception as e:
141
+ tables_dict[tbl_name] = pd.DataFrame({"Error": [str(e)]}) # DataFrame with the error message
142
+ #with open('tables_dict_beaver.pkl', 'wb') as f:
143
+ # pickle.dump(tables_dict, f)
144
+ return tables_dict
145
+
146
+
147
+ def extract_answer(df):
148
+ if "query" not in df.columns or "db_path" not in df.columns:
149
+ raise ValueError("The DataFrame must contain 'query' and 'db_path' columns.")
150
+
151
+ answers = []
152
+ for _, row in df.iterrows():
153
+ query = row["query"]
154
+ db_path = row["db_path"]
155
+ try:
156
+ conn = SqliteConnector(relative_db_path = db_path , db_name= "db")
157
+ answer = eva._utils_run_query_if_str(query, conn)
158
+ answers.append(answer)
159
+ except Exception as e:
160
+ answers.append(f"Error: {e}")
161
+
162
+ df["target_answer"] = answers
163
+ return df
164
+
165
+ evaluator = {
166
+ "cell_precision": CellPrecision(),
167
+ "cell_recall": CellRecall(),
168
+ "tuple_cardinality": TupleCardinality(),
169
+ "tuple_order": TupleOrder(),
170
+ "tuple_constraint": TupleConstraint(),
171
+ "execution_accuracy": ExecutionAccuracy(),
172
+ "valid_efficency_score": ValidEfficiencyScore()
173
+ }
174
+
175
+ def evaluate_answer(df):
176
+ for metric_name, metric in evaluator.items():
177
+ results = []
178
+ for _, row in df.iterrows():
179
+ target = row["target_answer"]
180
+ predicted = row["predicted_answer"]
181
+ try:
182
+ result = metric.run_metric(target = target, prediction = predicted)
183
+ except Exception as e:
184
+ result = None
185
+ results.append(result)
186
+ df[metric_name] = results
187
+ return df
188
+
189
+ models = [
190
+ "gpt-4o-mini",
191
+ "deepseek-ai/DeepSeek-R1-Distill-Llama-70B",
192
+ ]
193
+
194
+ def crop_entries_per_token(entries_list, model, prompt: str | None = None):
195
+ #open_ai_models = ["gpt-3.5", "gpt-4o-mini"]
196
+ dimension = 2048
197
+ #enties_string = [", ".join(map(str, entry)) for entry in entries_list]
198
+ if prompt:
199
+ entries_string = prompt.join(entries_list)
200
+ else:
201
+ entries_string = " ".join(entries_list)
202
+ #if model in ["deepseek-ai/DeepSeek-R1-Distill-Llama-70B" ,"gpt-4o-mini" ] :
203
+ #tokenizer = tiktoken.encoding_for_model("gpt-4o-mini")
204
+ tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path = "deepseek-ai/DeepSeek-R1-Distill-Llama-70B")
205
+
206
+ tokens = tokenizer.encode(entries_string)
207
+ number_of_tokens = len(tokens)
208
+ if number_of_tokens > dimension and len(entries_list) > 4:
209
+ entries_list = entries_list[:round(len(entries_list)/2)]
210
+ entries_list = crop_entries_per_token(entries_list, model)
211
+ return entries_list
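The new helpers are intended to run in sequence on the TQA path: extract_answer executes each gold query to fill a target_answer column, and evaluate_answer scores predicted_answer against it with the QATCH metric objects. A hedged usage sketch (question, query, path and predicted values are illustrative; the columns mirror the per-model frames built in qatch_flow_nl_sql):

```python
import pandas as pd
import utilities as us

# Minimal predictions frame, shaped like the per-model frames assembled in app.py.
preds = pd.DataFrame([{
    "question": "How many distinct warehouse load dates are there?",
    "query": "SELECT COUNT(DISTINCT WAREHOUSE_LOAD_DATE) FROM FAC_BUILDING_ADDRESS",
    "db_path": "./my_db.sqlite",
    "predicted_answer": [[3]],
}])

preds = us.extract_answer(preds)    # adds 'target_answer' by running each gold query
scored = us.evaluate_answer(preds)  # adds one column per metric, up to valid_efficency_score
print(scored[["cell_precision", "cell_recall", "execution_accuracy"]])
```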
utils_get_db_tables_info.py CHANGED
@@ -1,10 +1,10 @@
1
  import os
2
  import sqlite3
3
  import re
4
-
5
 
6
  def utils_extract_db_schema_as_string(
7
- db_id, base_path, normalize=False, sql: str | None = None, get_insert_into: bool = False
8
  ):
9
  """
10
  Extracts the full schema of an SQLite database into a single string.
@@ -19,7 +19,7 @@ def utils_extract_db_schema_as_string(
19
  cursor = connection.cursor()
20
 
21
  # Get the schema entries based on the provided SQL query
22
- schema_entries = _get_schema_entries(cursor, sql, get_insert_into)
23
 
24
  # Combine all schema definitions into a single string
25
  schema_string = _combine_schema_entries(schema_entries, normalize)
@@ -28,7 +28,7 @@ def utils_extract_db_schema_as_string(
28
 
29
 
30
 
31
- def _get_schema_entries(cursor, sql=None, get_insert_into=False):
32
  """
33
  Retrieves schema entries and optionally data entries from the SQLite database.
34
 
@@ -62,11 +62,17 @@ def _get_schema_entries(cursor, sql=None, get_insert_into=False):
62
  column_names = [description[0] for description in cursor.description]
63
 
64
  # Generate INSERT INTO statements for each row
65
- # TODO now hardcoded to first 3
66
- for row in rows[:3]:
 
 
67
  values = ', '.join(f"'{str(value)}'" if isinstance(value, str) else str(value) for value in row)
68
  insert_stmt = f"INSERT INTO {table} ({', '.join(column_names)}) VALUES ({values});"
69
  entries.append(insert_stmt)
 
 
70
 
71
  return entries
72
 
 
1
  import os
2
  import sqlite3
3
  import re
4
+ import utilities as us
5
 
6
  def utils_extract_db_schema_as_string(
7
+ db_id, base_path, model : str | None = None , normalize=False, sql: str | None = None, get_insert_into: bool = False, prompt : str | None = None
8
  ):
9
  """
10
  Extracts the full schema of an SQLite database into a single string.
 
19
  cursor = connection.cursor()
20
 
21
  # Get the schema entries based on the provided SQL query
22
+ schema_entries = _get_schema_entries(cursor, sql, get_insert_into, model, prompt)
23
 
24
  # Combine all schema definitions into a single string
25
  schema_string = _combine_schema_entries(schema_entries, normalize)
 
28
 
29
 
30
 
31
+ def _get_schema_entries(cursor, sql=None, get_insert_into=False, model: str | None = None, prompt : str | None = None):
32
  """
33
  Retrieves schema entries and optionally data entries from the SQLite database.
34
 
 
62
  column_names = [description[0] for description in cursor.description]
63
 
64
  # Generate INSERT INTO statements for each row
65
+ if model is None:
66
+ max_len=3
67
+ else:
68
+ max_len = len(rows)
69
+
70
+ for row in rows[:max_len]:
71
  values = ', '.join(f"'{str(value)}'" if isinstance(value, str) else str(value) for value in row)
72
  insert_stmt = f"INSERT INTO {table} ({', '.join(column_names)}) VALUES ({values});"
73
  entries.append(insert_stmt)
74
+
75
+ if model is not None: entries = us.crop_entries_per_token(entries, model, prompt)
76
 
77
  return entries
78
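The new model and prompt parameters mirror the call site in app.py: with model=None the previous behaviour (at most 3 INSERT INTO rows per table) is kept, while passing a model name includes every row and then trims the entries with us.crop_entries_per_token to the 2048-token budget. A hedged usage sketch with illustrative paths and queries:

```python
import utils_get_db_tables_info

# Previous behaviour: schema plus at most 3 INSERT INTO rows per referenced table.
schema_small = utils_get_db_tables_info.utils_extract_db_schema_as_string(
    db_id="my_db",
    base_path="./my_db.sqlite",
    sql="SELECT * FROM FAC_BUILDING_ADDRESS",
    get_insert_into=True,
)

# New behaviour: include all rows, then let us.crop_entries_per_token trim the
# INSERT INTO entries until they fit the token budget for the given model.
schema_cropped = utils_get_db_tables_info.utils_extract_db_schema_as_string(
    db_id="my_db",
    base_path="./my_db.sqlite",
    sql="SELECT * FROM FAC_BUILDING_ADDRESS",
    get_insert_into=True,
    model="deepseek-ai/DeepSeek-R1-Distill-Llama-70B",
    prompt="Translate the question into SQL.",
)
```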