Add TQA task #21
by franceth - opened
- app.py +196 -146
- concatenated_output.csv +1 -1
- utilities.py +97 -1
- utils_get_db_tables_info.py +12 -6
app.py
CHANGED
@@ -12,6 +12,7 @@ import plotly.colors as pc
|
|
12 |
from qatch.connectors.sqlite_connector import SqliteConnector
|
13 |
from qatch.generate_dataset.orchestrator_generator import OrchestratorGenerator
|
14 |
from qatch.evaluate_dataset.orchestrator_evaluator import OrchestratorEvaluator
|
|
|
15 |
from prediction import ModelPrediction
|
16 |
import utils_get_db_tables_info
|
17 |
import utilities as us
|
@@ -31,7 +32,6 @@ import utilities as us
|
|
31 |
#pnp_path = os.path.join("data", "evaluation_p_np_metrics.csv")
|
32 |
pnp_path = "concatenated_output.csv"
|
33 |
PATH_PKL_TABLES = 'tables_dict_beaver.pkl'
|
34 |
-
|
35 |
js_func = """
|
36 |
function refresh() {
|
37 |
const url = new URL(window.location);
|
@@ -42,7 +42,8 @@ function refresh() {
|
|
42 |
}
|
43 |
}
|
44 |
"""
|
45 |
-
reset_flag=False
|
|
|
46 |
|
47 |
with open('style.css', 'r') as file:
|
48 |
css = file.read()
|
@@ -65,6 +66,8 @@ description = """## π Comparison of Proprietary and Non-Proprietary Databases
|
|
65 |
### β€ **Non-Proprietary**
|
66 |
###     β Spider 1.0 π·οΈ"""
|
67 |
prompt_default = "Translate the following question in SQL code to be executed over the database to fetch the answer.\nReturn the sql code in ```sql ```\nQuestion\n{question}\nDatabase Schema\n{db_schema}\n"
|
|
|
|
|
68 |
|
69 |
input_data = {
|
70 |
'input_method': "",
|
@@ -93,6 +96,7 @@ def load_data(file, path, use_default):
|
|
93 |
#change path
|
94 |
input_data["data_path"] = os.path.join(".", f"{input_data['db_name']}.sqlite")
|
95 |
input_data["data"] = us.load_data(file, input_data["db_name"])
|
|
|
96 |
df_current = input_data["data"]['data_frames'].get('MyTable', df_default) # Carica il DataFrame
|
97 |
if(input_data["data"]['data_frames'] and input_data["data"]["db"] is None): #for csv and xlsx files
|
98 |
table2primary_key = {}
|
@@ -317,7 +321,6 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
|
|
317 |
|
318 |
# Model selection button (initially disabled)
|
319 |
open_model_selection = gr.Button("Choose your models", interactive=False)
|
320 |
-
|
321 |
def update_table_list(data):
|
322 |
"""Dynamically updates the list of available tables and excluded ones."""
|
323 |
if isinstance(data, dict) and data:
|
@@ -458,9 +461,7 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
|
|
458 |
default_checkbox
|
459 |
]
|
460 |
)
|
461 |
-
|
462 |
reset_data.click(open_accordion, inputs=gr.State("reset"), outputs=[upload_acc, select_table_acc, select_model_acc, qatch_acc, metrics_acc, default_checkbox, file_input])
|
463 |
-
|
464 |
|
465 |
####################################
|
466 |
# MODEL SELECTION PART #
|
@@ -506,10 +507,9 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
|
|
506 |
# Function to get selected models
|
507 |
def get_selected_models(*model_selections):
|
508 |
selected_models = [model for model, selected in zip(model_list, model_selections) if selected]
|
509 |
-
|
510 |
input_data['models'] = selected_models
|
511 |
button_state = bool(selected_models and '{db_schema}' in input_data["prompt"] and '{question}' in input_data["prompt"])
|
512 |
-
return selected_models, gr.update(open=True, visible=True), gr.update(interactive=button_state)
|
513 |
|
514 |
# Add the Textbox to the interface
|
515 |
prompt = gr.TextArea(
|
@@ -517,17 +517,19 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
|
|
517 |
placeholder=prompt_default,
|
518 |
elem_id="custom-textarea"
|
519 |
)
|
|
|
520 |
warning_prompt = gr.Markdown(value="## Error in the prompt format", visible=False)
|
521 |
|
522 |
# Submit button (initially disabled)
|
523 |
-
|
524 |
-
|
|
|
525 |
|
526 |
def check_prompt(prompt):
|
527 |
#TODO
|
528 |
missing_elements = []
|
529 |
if(prompt==""):
|
530 |
-
input_data["prompt"]=prompt_default
|
531 |
button_state = bool(len(input_data['models']) > 0 and '{db_schema}' in input_data["prompt"] and '{question}' in input_data["prompt"])
|
532 |
else:
|
533 |
input_data["prompt"]=prompt
|
@@ -544,18 +546,18 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
|
|
544 |
), gr.update(interactive=button_state)
|
545 |
return gr.update(visible=False), gr.update(interactive=button_state)
|
546 |
|
547 |
-
prompt.change(fn=check_prompt, inputs=[prompt], outputs=[warning_prompt, submit_models_button])
|
548 |
# Link checkboxes to selection events
|
549 |
for checkbox in model_checkboxes:
|
550 |
checkbox.change(
|
551 |
fn=get_selected_models,
|
552 |
inputs=model_checkboxes,
|
553 |
-
outputs=[selected_models_output, select_model_acc, submit_models_button]
|
554 |
)
|
555 |
prompt.change(
|
556 |
fn=get_selected_models,
|
557 |
inputs=model_checkboxes,
|
558 |
-
outputs=[selected_models_output, select_model_acc, submit_models_button]
|
559 |
)
|
560 |
|
561 |
submit_models_button.click(
|
@@ -564,6 +566,17 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
|
|
564 |
outputs=[selected_models_output, select_model_acc, qatch_acc]
|
565 |
)
|
566 |
|
|
|
|
|
|
|
567 |
def enable_disable(enable):
|
568 |
return (
|
569 |
*[gr.update(interactive=enable) for _ in model_checkboxes],
|
@@ -574,6 +587,7 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
|
|
574 |
gr.update(interactive=enable),
|
575 |
gr.update(interactive=enable),
|
576 |
*[gr.update(interactive=enable) for _ in table_outputs],
|
|
|
577 |
gr.update(interactive=enable)
|
578 |
)
|
579 |
|
@@ -591,7 +605,24 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
|
|
591 |
default_checkbox,
|
592 |
table_selector,
|
593 |
*table_outputs,
|
594 |
-
open_model_selection
|
|
|
|
|
|
|
|
595 |
]
|
596 |
)
|
597 |
|
@@ -609,7 +640,8 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
|
|
609 |
default_checkbox,
|
610 |
table_selector,
|
611 |
*table_outputs,
|
612 |
-
open_model_selection
|
|
|
613 |
]
|
614 |
)
|
615 |
|
@@ -660,9 +692,10 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
|
|
660 |
{mirrored_symbols}
|
661 |
</div>
|
662 |
"""
|
663 |
-
|
664 |
-
|
665 |
global reset_flag
|
|
|
666 |
predictions_dict = {model: pd.DataFrame(columns=['id', 'question', 'predicted_sql', 'time', 'query', 'db_path']) for model in model_list}
|
667 |
metrics_conc = pd.DataFrame()
|
668 |
columns_to_visulize = ["db_path", "tbl_name", "test_category", "sql_tag", "query", "question", "predicted_sql", "time", "price", "answer"]
|
@@ -692,7 +725,7 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
|
|
692 |
</div>
|
693 |
"""
|
694 |
yield gr.Markdown(), gr.Image(), gr.Markdown(load_text), gr.Markdown(display_question), gr.Markdown(), metrics_conc, *[predictions_dict[model][columns_to_visulize] for model in model_list]
|
695 |
-
|
696 |
prediction = row['predicted_sql']
|
697 |
|
698 |
display_prediction = f"""<div class='loading' style='font-size: 1.7rem; font-family: 'Inter', sans-serif;'>Predicted SQL:</div>
|
@@ -700,22 +733,25 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
|
|
700 |
<div style='font-size: 3rem'>β‘οΈ</div>
|
701 |
<div class='sqlquery' font-family: 'Inter', sans-serif;>{prediction}</div>
|
702 |
</div>
|
703 |
-
"""
|
|
|
704 |
yield gr.Markdown(), gr.Image(), gr.Markdown(load_text), gr.Markdown(), gr.Markdown(display_prediction), metrics_conc, *[predictions_dict[model][columns_to_visulize] for model in model_list]
|
705 |
yield gr.Markdown(), gr.Image(), gr.Markdown(load_text), gr.Markdown(), gr.Markdown(display_prediction), metrics_conc, *[predictions_dict[model][columns_to_visulize] for model in model_list]
|
706 |
metrics_conc = target_df
|
707 |
-
if 'valid_efficiency_score' not in metrics_conc.columns:
|
708 |
-
metrics_conc['valid_efficiency_score'] = metrics_conc['VES']
|
709 |
eval_text = generate_eval_text("End evaluation")
|
710 |
yield gr.Markdown(eval_text, visible=True), gr.Image(), gr.Markdown(), gr.Markdown(), gr.Markdown(), metrics_conc, *[predictions_dict[model][columns_to_visulize] for model in model_list]
|
|
|
711 |
else:
|
712 |
-
|
713 |
orchestrator_generator = OrchestratorGenerator()
|
714 |
-
|
715 |
-
|
716 |
-
#
|
717 |
-
|
718 |
-
|
|
|
|
|
719 |
|
720 |
predictor = ModelPrediction()
|
721 |
reset_flag = False
|
@@ -736,15 +772,18 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
|
|
736 |
</div>
|
737 |
"""
|
738 |
yield gr.Markdown(), gr.Image(), gr.Markdown(load_text), gr.Markdown(display_question), gr.Markdown(), metrics_conc, *[predictions_dict[model]for model in model_list]
|
739 |
-
|
740 |
-
|
741 |
-
|
742 |
-
|
743 |
-
|
744 |
-
|
745 |
-
|
746 |
-
|
747 |
-
|
|
|
|
|
|
|
748 |
)
|
749 |
|
750 |
#prompt_to_send = us.prepare_prompt(input_data["prompt"], question, schema_text, samples)
|
@@ -752,19 +791,27 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
|
|
752 |
#PREDICTION SQL
|
753 |
|
754 |
# TODO add button for QA or SP and pass to .make_prediction parameter TASK
|
|
|
|
|
|
|
755 |
response = predictor.make_prediction(
|
756 |
question=question,
|
757 |
-
db_schema=
|
758 |
model_name=model,
|
759 |
prompt=f"{prompt_to_send}",
|
760 |
-
task=
|
761 |
)
|
762 |
prediction = response['response_parsed']
|
763 |
price = response['cost']
|
764 |
answer = response['response']
|
765 |
|
766 |
end_time = time.time()
|
767 |
-
|
|
|
|
|
|
|
|
|
|
|
768 |
<div style='display: flex; align-items: center;'>
|
769 |
<div style='font-size: 3rem'>β‘οΈ</div>
|
770 |
<div class='sqlquery' font-family: 'Inter', sans-serif;>{prediction}</div>
|
@@ -779,40 +826,47 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
|
|
779 |
'query': row["query"],
|
780 |
'db_path': input_data["data_path"],
|
781 |
'price':price,
|
782 |
-
'answer':answer,
|
783 |
'number_question':count,
|
784 |
-
'
|
|
|
785 |
}]).dropna(how="all") # Remove only completely empty rows
|
786 |
count=count+1
|
787 |
# TODO: use a for loop
|
|
|
|
|
788 |
for col in target_df.columns:
|
789 |
if col not in new_row.columns:
|
790 |
new_row[col] = row[col]
|
791 |
-
|
792 |
# Update model's prediction dataframe incrementally
|
793 |
if not new_row.empty:
|
794 |
predictions_dict[model] = pd.concat([predictions_dict[model], new_row], ignore_index=True)
|
795 |
|
796 |
# yield gr.Textbox(), gr.Textbox(prediction), *[predictions_dict[model] for model in input_data["models"]], None
|
797 |
yield gr.Markdown(), gr.Image(), gr.Markdown(load_text), gr.Markdown(), gr.Markdown(display_prediction), metrics_conc, *[predictions_dict[model]for model in model_list]
|
798 |
-
|
799 |
yield gr.Markdown(), gr.Image(), gr.Markdown(load_text), gr.Markdown(), gr.Markdown(display_prediction), metrics_conc, *[predictions_dict[model] for model in model_list]
|
800 |
# END
|
801 |
eval_text = generate_eval_text("Evaluation")
|
802 |
yield gr.Markdown(eval_text, visible=True), gr.Image(), gr.Markdown(), gr.Markdown(), gr.Markdown(), metrics_conc, *[predictions_dict[model] for model in model_list]
|
|
|
803 |
evaluator = OrchestratorEvaluator()
|
|
|
804 |
for model in input_data["models"]:
|
805 |
-
metrics_df_model = evaluator.evaluate_df(
|
806 |
-
df=predictions_dict[model],
|
807 |
-
target_col_name="query",
|
808 |
-
prediction_col_name="predicted_sql",
|
809 |
-
db_path_name="db_path"
|
810 |
-
)
|
|
|
|
|
|
|
811 |
metrics_df_model['model'] = model
|
812 |
metrics_conc = pd.concat([metrics_conc, metrics_df_model], ignore_index=True)
|
813 |
-
|
814 |
-
if 'valid_efficiency_score' not in metrics_conc.columns:
|
815 |
-
metrics_conc['valid_efficiency_score'] = metrics_conc['VES']
|
|
|
816 |
eval_text = generate_eval_text("End evaluation")
|
817 |
yield gr.Markdown(eval_text, visible=True), gr.Image(), gr.Markdown(), gr.Markdown(), gr.Markdown(), metrics_conc, *[predictions_dict[model] for model in model_list]
|
818 |
|
@@ -848,6 +902,7 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
|
|
848 |
gr.Markdown(f"**Results for {model}**")
|
849 |
tab_dict[model] = tab
|
850 |
dataframe_per_model[model] = gr.DataFrame()
|
|
|
851 |
# download_pred_model = gr.DownloadButton(label="Download Prediction per Model", visible=False)
|
852 |
|
853 |
evaluation_loading = gr.Markdown()
|
@@ -860,13 +915,24 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
|
|
860 |
inputs=[],
|
861 |
outputs=[tab_dict[model] for model in model_list] # Update TabItem visibility
|
862 |
)
|
|
|
|
|
|
|
|
|
|
|
863 |
|
864 |
selected_models_display = gr.JSON(label="Final input data", visible=False)
|
865 |
metrics_df = gr.DataFrame(visible=False)
|
866 |
metrics_df_out = gr.DataFrame(visible=False)
|
867 |
|
868 |
submit_models_button.click(
|
869 |
-
fn=
|
|
|
|
|
|
|
|
|
|
|
|
|
870 |
inputs=[],
|
871 |
outputs=[evaluation_loading, model_logo, variable, question_display, prediction_display, metrics_df] + list(dataframe_per_model.values())
|
872 |
)
|
@@ -875,6 +941,10 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
|
|
875 |
fn=lambda: gr.update(value=input_data),
|
876 |
outputs=[selected_models_display]
|
877 |
)
|
|
|
|
|
|
|
|
|
878 |
|
879 |
# Works for METRICS
|
880 |
metrics_df.change(fn=change_text, inputs=[metrics_df], outputs=[metrics_df_out])
|
@@ -897,10 +967,16 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
|
|
897 |
fn=lambda: gr.update(visible=False),
|
898 |
outputs=[download_metrics]
|
899 |
)
|
|
|
|
|
|
|
|
|
900 |
|
901 |
def refresh():
|
902 |
global reset_flag
|
|
|
903 |
reset_flag = True
|
|
|
904 |
|
905 |
reset_data = gr.Button("Back to upload data section", interactive=True)
|
906 |
|
@@ -926,10 +1002,12 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
|
|
926 |
default_checkbox,
|
927 |
table_selector,
|
928 |
*table_outputs,
|
929 |
-
open_model_selection
|
|
|
930 |
]
|
931 |
)
|
932 |
-
|
|
|
933 |
##########################################
|
934 |
# METRICS VISUALIZATION SECTION #
|
935 |
##########################################
|
@@ -944,8 +1022,9 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
|
|
944 |
####################################
|
945 |
|
946 |
def load_data_csv_es():
|
947 |
-
|
948 |
if input_data["input_method"]=="default":
|
|
|
949 |
df = pd.read_csv(pnp_path)
|
950 |
df = df[df['model'].isin(input_data["models"])]
|
951 |
df = df[df['tbl_name'].isin(input_data["data"]["selected_tables"])]
|
@@ -956,6 +1035,7 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
|
|
956 |
df['model'] = df['model'].replace('llama-70', 'Llama-70B')
|
957 |
df['model'] = df['model'].replace('llama-8', 'Llama-8B')
|
958 |
df['test_category'] = df['test_category'].replace('many-to-many-generator', 'MANY-TO-MANY')
|
|
|
959 |
return df
|
960 |
return metrics_df_out
|
961 |
|
@@ -998,20 +1078,21 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
|
|
998 |
|
999 |
DB_CATEGORY_COLORS = generate_db_category_colors()
|
1000 |
|
1001 |
-
def
|
1002 |
-
|
1003 |
-
|
1004 |
-
|
1005 |
-
|
1006 |
-
min_val = df['valid_efficiency_score'].min()
|
1007 |
-
max_val = df['valid_efficiency_score'].max()
|
1008 |
|
1009 |
-
if min_val == max_val:
|
1010 |
-
|
1011 |
-
|
|
|
|
|
|
|
1012 |
else:
|
1013 |
-
df['
|
1014 |
-
df['
|
1015 |
) / (max_val - min_val)
|
1016 |
|
1017 |
return df
|
@@ -1024,7 +1105,7 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
|
|
1024 |
# BAR CHART FOR AVERAGE METRICS WITH UPDATE FUNCTION
|
1025 |
def plot_metric(df, radio_metric, qatch_selected_metrics, external_selected_metric, group_by, selected_models):
|
1026 |
df = df[df['model'].isin(selected_models)]
|
1027 |
-
df = normalize_valid_efficiency_score(df)
|
1028 |
|
1029 |
# Mappatura nomi leggibili -> tecnici
|
1030 |
qatch_selected_internal = [qatch_metrics_dict[label] for label in qatch_selected_metrics]
|
@@ -1141,7 +1222,7 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
|
|
1141 |
selected_models = [selected_models]
|
1142 |
|
1143 |
df = df[df['model'].isin(selected_models)]
|
1144 |
-
df = normalize_valid_efficiency_score(df)
|
1145 |
|
1146 |
# Converti nomi leggibili -> tecnici
|
1147 |
qatch_selected_internal = [qatch_metrics_dict[label] for label in qatch_selected_metrics]
|
@@ -1226,54 +1307,6 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
|
|
1226 |
)
|
1227 |
|
1228 |
return gr.Plot(fig, visible=True)
|
1229 |
-
|
1230 |
-
"""
|
1231 |
-
def plot_metric_propietary(df, radio_metric, qatch_selected_metrics, external_selected_metric, selected_models):
|
1232 |
-
if selected_models == "All":
|
1233 |
-
selected_models = models
|
1234 |
-
else:
|
1235 |
-
selected_models = [selected_models]
|
1236 |
-
|
1237 |
-
df = df[df['model'].isin(selected_models)]
|
1238 |
-
df = normalize_valid_efficiency_score(df)
|
1239 |
-
|
1240 |
-
if radio_metric == "Qatch":
|
1241 |
-
selected_metrics = qatch_selected_metrics
|
1242 |
-
else:
|
1243 |
-
selected_metrics = external_selected_metric
|
1244 |
-
|
1245 |
-
df = calculate_average_metrics(df, selected_metrics)
|
1246 |
-
|
1247 |
-
# Raggruppamento per modello e categoria
|
1248 |
-
avg_metrics = df.groupby(["model", "db_category"])['avg_metric'].mean().reset_index()
|
1249 |
-
avg_metrics['text_label'] = avg_metrics['avg_metric'].apply(lambda x: f'{x:.2f}')
|
1250 |
-
|
1251 |
-
# Plot orizzontale con modello sull'asse Y
|
1252 |
-
fig = px.bar(
|
1253 |
-
avg_metrics,
|
1254 |
-
x='avg_metric',
|
1255 |
-
y='model',
|
1256 |
-
color='db_category', # categoria come colore
|
1257 |
-
text='text_label',
|
1258 |
-
barmode='group',
|
1259 |
-
orientation='h',
|
1260 |
-
color_discrete_map=DB_CATEGORY_COLORS, # devi avere questo dict come MODEL_COLORS
|
1261 |
-
title='Average metric per model and db_category π',
|
1262 |
-
labels={'avg_metric': 'AVG Metric', 'model': 'Model'},
|
1263 |
-
template='plotly_dark'
|
1264 |
-
)
|
1265 |
-
|
1266 |
-
fig.update_traces(textposition='outside', textfont_size=10)
|
1267 |
-
fig.update_layout(
|
1268 |
-
margin=dict(t=80),
|
1269 |
-
yaxis=dict(title=''),
|
1270 |
-
xaxis=dict(title='AVG Metrics'),
|
1271 |
-
legend_title='DB Name',
|
1272 |
-
height=600 # puoi aumentare se ci sono tanti modelli
|
1273 |
-
)
|
1274 |
-
|
1275 |
-
return gr.Plot(fig, visible=True)
|
1276 |
-
"""
|
1277 |
|
1278 |
def update_plot_propietary(radio_metric, qatch_selected_metrics, external_selected_metric, selected_models):
|
1279 |
df = load_data_csv_es()
|
@@ -1289,7 +1322,7 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
|
|
1289 |
df = df[df['db_category'].isin(target_cats)]
|
1290 |
df = df[df['model'].isin(selected_models)]
|
1291 |
|
1292 |
-
df = normalize_valid_efficiency_score(df)
|
1293 |
df = calculate_average_metrics(df, qatch_metrics)
|
1294 |
|
1295 |
# Calcola la media per db_category e modello
|
@@ -1410,14 +1443,14 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
|
|
1410 |
|
1411 |
# RADAR OR BAR CHART BASED ON CATEGORY COUNT
|
1412 |
def plot_radar(df, selected_models, selected_metrics, selected_categories):
|
1413 |
-
if "external" in selected_metrics:
|
1414 |
-
selected_metrics = ["execution_accuracy", "valid_efficiency_score"]
|
1415 |
else:
|
1416 |
selected_metrics = ["cell_precision", "cell_recall", "tuple_order", "tuple_cardinality", "tuple_constraint"]
|
1417 |
|
1418 |
# Filtro modelli e normalizzazione
|
1419 |
df = df[df['model'].isin(selected_models)]
|
1420 |
-
df = normalize_valid_efficiency_score(df)
|
1421 |
df = calculate_average_metrics(df, selected_metrics)
|
1422 |
|
1423 |
avg_metrics = df.groupby(['model', 'test_category'])['avg_metric'].mean().reset_index()
|
@@ -1574,13 +1607,13 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
|
|
1574 |
|
1575 |
# RADAR OR BAR CHART FOR SUB-CATEGORIES BASED ON CATEGORY COUNT
|
1576 |
def plot_radar_sub(df, selected_models, selected_metrics, selected_category):
|
1577 |
-
if "external" in selected_metrics:
|
1578 |
-
selected_metrics = ["execution_accuracy", "valid_efficiency_score"]
|
1579 |
else:
|
1580 |
selected_metrics = ["cell_precision", "cell_recall", "tuple_order", "tuple_cardinality", "tuple_constraint"]
|
1581 |
|
1582 |
df = df[df['model'].isin(selected_models)]
|
1583 |
-
df = normalize_valid_efficiency_score(df)
|
1584 |
df = calculate_average_metrics(df, selected_metrics)
|
1585 |
|
1586 |
if isinstance(selected_category, str):
|
@@ -1743,6 +1776,7 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
|
|
1743 |
|
1744 |
# RANKING FOR THE 3 WORST RESULTS WITH UPDATE FUNCTION
|
1745 |
def worst_cases_text(df, selected_models, selected_metrics, selected_categories):
|
|
|
1746 |
if selected_models == "All":
|
1747 |
selected_models = models
|
1748 |
else:
|
@@ -1757,15 +1791,25 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
|
|
1757 |
df = df[df['test_category'].isin(selected_categories)]
|
1758 |
|
1759 |
if "external" in selected_metrics:
|
1760 |
-
selected_metrics = ["execution_accuracy", "valid_efficiency_score"]
|
1761 |
else:
|
1762 |
selected_metrics = ["cell_precision", "cell_recall", "tuple_order", "tuple_cardinality", "tuple_constraint"]
|
1763 |
|
1764 |
-
df = normalize_valid_efficiency_score(df)
|
1765 |
df = calculate_average_metrics(df, selected_metrics)
|
1766 |
-
|
1767 |
-
worst_cases_df = df.groupby(['model', 'tbl_name', 'test_category', 'question', 'query', 'predicted_sql', 'answer', 'sql_tag'])['avg_metric'].mean().reset_index()
|
1768 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1769 |
worst_cases_df = worst_cases_df.sort_values(by="avg_metric", ascending=True).reset_index(drop=True)
|
1770 |
|
1771 |
worst_cases_top_3 = worst_cases_df.head(3)
|
@@ -1778,14 +1822,24 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
|
|
1778 |
medals = ["π₯", "π₯", "π₯"]
|
1779 |
|
1780 |
for i, row in worst_cases_top_3.iterrows():
|
1781 |
-
|
1782 |
-
|
1783 |
-
|
1784 |
-
|
1785 |
-
|
1786 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1787 |
|
1788 |
-
|
1789 |
|
1790 |
raw_answer = (
|
1791 |
f"<span style='font-size:18px;'><b>{medals[i]} {row['model']} - {row['tbl_name']} - {row['test_category']} - {row['sql_tag']}</b> ({row['avg_metric']})</span> \n"
|
@@ -1793,7 +1847,7 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
|
|
1793 |
)
|
1794 |
|
1795 |
answer_str.append(raw_answer)
|
1796 |
-
|
1797 |
return worst_str[0], worst_str[1], worst_str[2], answer_str[0], answer_str[1], answer_str[2]
|
1798 |
|
1799 |
def update_worst_cases_text(selected_models, selected_metrics, selected_categories):
|
@@ -1803,7 +1857,7 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
|
|
1803 |
# LINE CHART FOR CUMULATIVE TIME WITH UPDATE FUNCTION
|
1804 |
def plot_cumulative_flow(df, selected_models, max_points):
|
1805 |
df = df[df['model'].isin(selected_models)]
|
1806 |
-
df = normalize_valid_efficiency_score(df)
|
1807 |
|
1808 |
fig = go.Figure()
|
1809 |
|
@@ -1937,10 +1991,10 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
|
|
1937 |
|
1938 |
external_metrics_dict = {
|
1939 |
"Execution Accuracy": "execution_accuracy",
|
1940 |
-
"Valid Efficiency Score": "valid_efficiency_score"
|
1941 |
}
|
1942 |
|
1943 |
-
external_metric = ["execution_accuracy", "valid_efficiency_score"]
|
1944 |
last_valid_external_metric_selection = external_metric.copy()
|
1945 |
def enforce_external_metric_selection(selected):
|
1946 |
global last_valid_external_metric_selection
|
@@ -1987,10 +2041,6 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
|
|
1987 |
|
1988 |
all_model_as_dic = {cat: [f"{cat}"] for cat in models}
|
1989 |
all_model_as_dic["All"] = models
|
1990 |
-
|
1991 |
-
#with gr.Blocks(theme=gr.themes.Default(primary_hue='blue')) as demo:
|
1992 |
-
|
1993 |
-
|
1994 |
|
1995 |
###########################
|
1996 |
# VISUALIZATION SECTION #
|
@@ -2029,7 +2079,7 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
|
|
2029 |
<span
|
2030 |
title="External metric info:
|
2031 |
Execution Accuracy: Checks if the predicted query returns exactly the same result as the ground truth query when executed. It is a binary metric: 1 if the output matches, 0 otherwise.
|
2032 |
-
Valid Efficiency Score: Evaluates the efficiency of a query by combining execution time and correctness. It rewards queries that are both accurate and fast."
|
2033 |
style="margin-left: 6px; cursor: help; color: #00bfff; font-size: 16px; white-space: pre-line;"
|
2034 |
>External metric info βΉοΈ</span>
|
2035 |
</div>
|
@@ -2304,6 +2354,6 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
|
|
2304 |
reset_data.click(open_accordion, inputs=gr.State("reset"), outputs=[upload_acc, select_table_acc, select_model_acc, qatch_acc, metrics_acc, default_checkbox, file_input])
|
2305 |
reset_data.click(fn=lambda: gr.update(visible=False), outputs=[download_metrics])
|
2306 |
reset_data.click(fn=enable_disable, inputs=[gr.State(True)], outputs=[*model_checkboxes, submit_models_button, preview_output, submit_button, file_input, default_checkbox, table_selector, *table_outputs, open_model_selection])
|
2307 |
-
|
2308 |
|
2309 |
interface.launch(share = True)
|
|
|
12 |
from qatch.connectors.sqlite_connector import SqliteConnector
|
13 |
from qatch.generate_dataset.orchestrator_generator import OrchestratorGenerator
|
14 |
from qatch.evaluate_dataset.orchestrator_evaluator import OrchestratorEvaluator
|
15 |
+
import qatch.evaluate_dataset.orchestrator_evaluator as eva
|
16 |
from prediction import ModelPrediction
|
17 |
import utils_get_db_tables_info
|
18 |
import utilities as us
|
|
|
32 |
#pnp_path = os.path.join("data", "evaluation_p_np_metrics.csv")
|
33 |
pnp_path = "concatenated_output.csv"
|
34 |
PATH_PKL_TABLES = 'tables_dict_beaver.pkl'
|
|
|
35 |
js_func = """
|
36 |
function refresh() {
|
37 |
const url = new URL(window.location);
|
|
|
42 |
}
|
43 |
}
|
44 |
"""
|
45 |
+
reset_flag = False
|
46 |
+
flag_TQA = False
|
47 |
|
48 |
with open('style.css', 'r') as file:
|
49 |
css = file.read()
|
|
|
66 |
### β€ **Non-Proprietary**
|
67 |
###     β Spider 1.0 π·οΈ"""
|
68 |
prompt_default = "Translate the following question in SQL code to be executed over the database to fetch the answer.\nReturn the sql code in ```sql ```\nQuestion\n{question}\nDatabase Schema\n{db_schema}\n"
|
69 |
+
prompt_default_tqa = "Return the answer of the following question based on the provided database. Return your answer as the result of a query executed over the database. Namely, as a list of list where the first list represent the tuples and the second list the values in that tuple.\n Return the answer in answer tag as <answer> </answer>.\n Question \n {question}\n Database Schema\n {db_schema}\n"
|
70 |
+
|
71 |
|
72 |
input_data = {
|
73 |
'input_method': "",
|
|
|
96 |
#change path
|
97 |
input_data["data_path"] = os.path.join(".", f"{input_data['db_name']}.sqlite")
|
98 |
input_data["data"] = us.load_data(file, input_data["db_name"])
|
99 |
+
|
100 |
df_current = input_data["data"]['data_frames'].get('MyTable', df_default) # Carica il DataFrame
|
101 |
if(input_data["data"]['data_frames'] and input_data["data"]["db"] is None): #for csv and xlsx files
|
102 |
table2primary_key = {}
|
|
|
321 |
|
322 |
# Model selection button (initially disabled)
|
323 |
open_model_selection = gr.Button("Choose your models", interactive=False)
|
|
|
324 |
def update_table_list(data):
|
325 |
"""Dynamically updates the list of available tables and excluded ones."""
|
326 |
if isinstance(data, dict) and data:
|
|
|
461 |
default_checkbox
|
462 |
]
|
463 |
)
|
|
|
464 |
reset_data.click(open_accordion, inputs=gr.State("reset"), outputs=[upload_acc, select_table_acc, select_model_acc, qatch_acc, metrics_acc, default_checkbox, file_input])
|
|
|
465 |
|
466 |
####################################
|
467 |
# MODEL SELECTION PART #
|
|
|
507 |
# Function to get selected models
|
508 |
def get_selected_models(*model_selections):
|
509 |
selected_models = [model for model, selected in zip(model_list, model_selections) if selected]
|
|
|
510 |
input_data['models'] = selected_models
|
511 |
button_state = bool(selected_models and '{db_schema}' in input_data["prompt"] and '{question}' in input_data["prompt"])
|
512 |
+
return selected_models, gr.update(open=True, visible=True), gr.update(interactive=button_state), gr.update(interactive=button_state)
|
513 |
|
514 |
# Add the Textbox to the interface
|
515 |
prompt = gr.TextArea(
|
|
|
517 |
placeholder=prompt_default,
|
518 |
elem_id="custom-textarea"
|
519 |
)
|
520 |
+
|
521 |
warning_prompt = gr.Markdown(value="## Error in the prompt format", visible=False)
|
522 |
|
523 |
# Submit button (initially disabled)
|
524 |
+
with gr.Row():
|
525 |
+
submit_models_button = gr.Button("Submit Models for NL2SQL task", interactive=False)
|
526 |
+
submit_models_button_tqa = gr.Button("Submit Models for TQA task", interactive=False)
|
527 |
|
528 |
def check_prompt(prompt):
|
529 |
#TODO
|
530 |
missing_elements = []
|
531 |
if(prompt==""):
|
532 |
+
input_data["prompt"] = prompt_default
|
533 |
button_state = bool(len(input_data['models']) > 0 and '{db_schema}' in input_data["prompt"] and '{question}' in input_data["prompt"])
|
534 |
else:
|
535 |
input_data["prompt"]=prompt
|
|
|
546 |
), gr.update(interactive=button_state)
|
547 |
return gr.update(visible=False), gr.update(interactive=button_state)
|
548 |
|
549 |
+
prompt.change(fn=check_prompt, inputs=[prompt], outputs=[warning_prompt, submit_models_button, submit_models_button_tqa])
|
550 |
# Link checkboxes to selection events
|
551 |
for checkbox in model_checkboxes:
|
552 |
checkbox.change(
|
553 |
fn=get_selected_models,
|
554 |
inputs=model_checkboxes,
|
555 |
+
outputs=[selected_models_output, select_model_acc, submit_models_button, submit_models_button_tqa]
|
556 |
)
|
557 |
prompt.change(
|
558 |
fn=get_selected_models,
|
559 |
inputs=model_checkboxes,
|
560 |
+
outputs=[selected_models_output, select_model_acc, submit_models_button, submit_models_button_tqa]
|
561 |
)
|
562 |
|
563 |
submit_models_button.click(
|
|
|
566 |
outputs=[selected_models_output, select_model_acc, qatch_acc]
|
567 |
)
|
568 |
|
569 |
+
submit_models_button_tqa.click(
|
570 |
+
fn=lambda *args: (get_selected_models(*args), gr.update(open=False, visible=True), gr.update(open=True, visible=True)),
|
571 |
+
inputs=model_checkboxes,
|
572 |
+
outputs=[selected_models_output, select_model_acc, qatch_acc]
|
573 |
+
)
|
574 |
+
def change_flag():
|
575 |
+
global flag_TQA
|
576 |
+
flag_TQA = True
|
577 |
+
|
578 |
+
submit_models_button_tqa.click(fn = change_flag, inputs=[], outputs=[])
|
579 |
+
|
580 |
def enable_disable(enable):
|
581 |
return (
|
582 |
*[gr.update(interactive=enable) for _ in model_checkboxes],
|
|
|
587 |
gr.update(interactive=enable),
|
588 |
gr.update(interactive=enable),
|
589 |
*[gr.update(interactive=enable) for _ in table_outputs],
|
590 |
+
gr.update(interactive=enable),
|
591 |
gr.update(interactive=enable)
|
592 |
)
|
593 |
|
|
|
605 |
default_checkbox,
|
606 |
table_selector,
|
607 |
*table_outputs,
|
608 |
+
open_model_selection,
|
609 |
+
submit_models_button_tqa
|
610 |
+
]
|
611 |
+
)
|
612 |
+
submit_models_button_tqa.click(
|
613 |
+
fn=enable_disable,
|
614 |
+
inputs=[gr.State(False)],
|
615 |
+
outputs=[
|
616 |
+
*model_checkboxes,
|
617 |
+
submit_models_button,
|
618 |
+
preview_output,
|
619 |
+
submit_button,
|
620 |
+
file_input,
|
621 |
+
default_checkbox,
|
622 |
+
table_selector,
|
623 |
+
*table_outputs,
|
624 |
+
open_model_selection,
|
625 |
+
submit_models_button_tqa
|
626 |
]
|
627 |
)
|
628 |
|
|
|
640 |
default_checkbox,
|
641 |
table_selector,
|
642 |
*table_outputs,
|
643 |
+
open_model_selection,
|
644 |
+
submit_models_button_tqa
|
645 |
]
|
646 |
)
|
647 |
|
|
|
692 |
{mirrored_symbols}
|
693 |
</div>
|
694 |
"""
|
695 |
+
|
696 |
+
def qatch_flow_nl_sql():
|
697 |
global reset_flag
|
698 |
+
global flag_TQA
|
699 |
predictions_dict = {model: pd.DataFrame(columns=['id', 'question', 'predicted_sql', 'time', 'query', 'db_path']) for model in model_list}
|
700 |
metrics_conc = pd.DataFrame()
|
701 |
columns_to_visulize = ["db_path", "tbl_name", "test_category", "sql_tag", "query", "question", "predicted_sql", "time", "price", "answer"]
|
|
|
725 |
</div>
|
726 |
"""
|
727 |
yield gr.Markdown(), gr.Image(), gr.Markdown(load_text), gr.Markdown(display_question), gr.Markdown(), metrics_conc, *[predictions_dict[model][columns_to_visulize] for model in model_list]
|
728 |
+
|
729 |
prediction = row['predicted_sql']
|
730 |
|
731 |
display_prediction = f"""<div class='loading' style='font-size: 1.7rem; font-family: 'Inter', sans-serif;'>Predicted SQL:</div>
|
|
|
733 |
<div style='font-size: 3rem'>β‘οΈ</div>
|
734 |
<div class='sqlquery' font-family: 'Inter', sans-serif;>{prediction}</div>
|
735 |
</div>
|
736 |
+
"""
|
737 |
+
|
738 |
yield gr.Markdown(), gr.Image(), gr.Markdown(load_text), gr.Markdown(), gr.Markdown(display_prediction), metrics_conc, *[predictions_dict[model][columns_to_visulize] for model in model_list]
|
739 |
yield gr.Markdown(), gr.Image(), gr.Markdown(load_text), gr.Markdown(), gr.Markdown(display_prediction), metrics_conc, *[predictions_dict[model][columns_to_visulize] for model in model_list]
|
740 |
metrics_conc = target_df
|
741 |
+
if 'valid_efficency_score' not in metrics_conc.columns:
|
742 |
+
metrics_conc['valid_efficency_score'] = metrics_conc['VES']
|
743 |
eval_text = generate_eval_text("End evaluation")
|
744 |
yield gr.Markdown(eval_text, visible=True), gr.Image(), gr.Markdown(), gr.Markdown(), gr.Markdown(), metrics_conc, *[predictions_dict[model][columns_to_visulize] for model in model_list]
|
745 |
+
|
746 |
else:
|
|
|
747 |
orchestrator_generator = OrchestratorGenerator()
|
748 |
+
target_df = orchestrator_generator.generate_dataset(connector=input_data['data']['db'], tables_to_includes=input_data['data']['selected_tables'])
|
749 |
+
|
750 |
+
#create target_df[target_answer]
|
751 |
+
if flag_TQA :
|
752 |
+
if (input_data["prompt"] == prompt_default):
|
753 |
+
input_data["prompt"] = prompt_default_tqa
|
754 |
+
target_df = us.extract_answer(target_df)
|
755 |
|
756 |
predictor = ModelPrediction()
|
757 |
reset_flag = False
|
|
|
772 |
</div>
|
773 |
"""
|
774 |
yield gr.Markdown(), gr.Image(), gr.Markdown(load_text), gr.Markdown(display_question), gr.Markdown(), metrics_conc, *[predictions_dict[model]for model in model_list]
|
775 |
+
#samples = us.generate_some_samples(input_data["data_path"], row["tbl_name"])
|
776 |
+
model_to_send = None if not flag_TQA else model
|
777 |
+
|
778 |
+
|
779 |
+
db_schema_text = utils_get_db_tables_info.utils_extract_db_schema_as_string(
|
780 |
+
db_id = input_data["db_name"],
|
781 |
+
base_path = input_data["data_path"],
|
782 |
+
normalize=False,
|
783 |
+
sql=row["query"],
|
784 |
+
get_insert_into=True,
|
785 |
+
model = model_to_send,
|
786 |
+
prompt = input_data["prompt"].format(question=question, db_schema=""),
|
787 |
)
|
788 |
|
789 |
#prompt_to_send = us.prepare_prompt(input_data["prompt"], question, schema_text, samples)
|
|
|
791 |
#PREDICTION SQL
|
792 |
|
793 |
# TODO add button for QA or SP and pass to .make_prediction parameter TASK
|
794 |
+
if flag_TQA: task="QA"
|
795 |
+
else: task="SP"
|
796 |
+
start_time = time.time()
|
797 |
response = predictor.make_prediction(
|
798 |
question=question,
|
799 |
+
db_schema=db_schema_text,
|
800 |
model_name=model,
|
801 |
prompt=f"{prompt_to_send}",
|
802 |
+
task=task
|
803 |
)
|
804 |
prediction = response['response_parsed']
|
805 |
price = response['cost']
|
806 |
answer = response['response']
|
807 |
|
808 |
end_time = time.time()
|
809 |
+
if flag_TQA:
|
810 |
+
task_string = "Answer"
|
811 |
+
else:
|
812 |
+
task_string = "SQL"
|
813 |
+
|
814 |
+
display_prediction = f"""<div class='loading' style='font-size: 1.7rem; font-family: 'Inter', sans-serif;'>Predicted {task_string}:</div>
|
815 |
<div style='display: flex; align-items: center;'>
|
816 |
<div style='font-size: 3rem'>β‘οΈ</div>
|
817 |
<div class='sqlquery' font-family: 'Inter', sans-serif;>{prediction}</div>
|
|
|
826 |
'query': row["query"],
|
827 |
'db_path': input_data["data_path"],
|
828 |
'price':price,
|
829 |
+
'answer': answer,
|
830 |
'number_question':count,
|
831 |
+
'target_answer' : row["target_answer"] if flag_TQA else None,
|
832 |
+
|
833 |
}]).dropna(how="all") # Remove only completely empty rows
|
834 |
count=count+1
|
835 |
# TODO: use a for loop
|
836 |
+
if (flag_TQA) :
|
837 |
+
new_row['predicted_answer'] = prediction
|
838 |
for col in target_df.columns:
|
839 |
if col not in new_row.columns:
|
840 |
new_row[col] = row[col]
|
|
|
841 |
# Update model's prediction dataframe incrementally
|
842 |
if not new_row.empty:
|
843 |
predictions_dict[model] = pd.concat([predictions_dict[model], new_row], ignore_index=True)
|
844 |
|
845 |
# yield gr.Textbox(), gr.Textbox(prediction), *[predictions_dict[model] for model in input_data["models"]], None
|
846 |
yield gr.Markdown(), gr.Image(), gr.Markdown(load_text), gr.Markdown(), gr.Markdown(display_prediction), metrics_conc, *[predictions_dict[model]for model in model_list]
|
|
|
847 |
yield gr.Markdown(), gr.Image(), gr.Markdown(load_text), gr.Markdown(), gr.Markdown(display_prediction), metrics_conc, *[predictions_dict[model] for model in model_list]
|
848 |
# END
|
849 |
eval_text = generate_eval_text("Evaluation")
|
850 |
yield gr.Markdown(eval_text, visible=True), gr.Image(), gr.Markdown(), gr.Markdown(), gr.Markdown(), metrics_conc, *[predictions_dict[model] for model in model_list]
|
851 |
+
|
852 |
evaluator = OrchestratorEvaluator()
|
853 |
+
|
854 |
for model in input_data["models"]:
|
855 |
+
if not flag_TQA:
|
856 |
+
metrics_df_model = evaluator.evaluate_df(
|
857 |
+
df=predictions_dict[model],
|
858 |
+
target_col_name="query",
|
859 |
+
prediction_col_name="predicted_sql",
|
860 |
+
db_path_name="db_path"
|
861 |
+
)
|
862 |
+
else:
|
863 |
+
metrics_df_model = us.evaluate_answer(predictions_dict[model])
|
864 |
metrics_df_model['model'] = model
|
865 |
metrics_conc = pd.concat([metrics_conc, metrics_df_model], ignore_index=True)
|
866 |
+
|
867 |
+
if 'valid_efficency_score' not in metrics_conc.columns and 'VES' in metrics_conc.columns:
|
868 |
+
metrics_conc['valid_efficency_score'] = metrics_conc['VES']
|
869 |
+
|
870 |
eval_text = generate_eval_text("End evaluation")
|
871 |
yield gr.Markdown(eval_text, visible=True), gr.Image(), gr.Markdown(), gr.Markdown(), gr.Markdown(), metrics_conc, *[predictions_dict[model] for model in model_list]
|
872 |
|
|
|
902 |
gr.Markdown(f"**Results for {model}**")
|
903 |
tab_dict[model] = tab
|
904 |
dataframe_per_model[model] = gr.DataFrame()
|
905 |
+
#TODO download metrics per model
|
906 |
# download_pred_model = gr.DownloadButton(label="Download Prediction per Model", visible=False)
|
907 |
|
908 |
evaluation_loading = gr.Markdown()
|
|
|
915 |
inputs=[],
|
916 |
outputs=[tab_dict[model] for model in model_list] # Update TabItem visibility
|
917 |
)
|
918 |
+
submit_models_button_tqa.click(
|
919 |
+
change_tab,
|
920 |
+
inputs=[],
|
921 |
+
outputs=[tab_dict[model] for model in model_list] # Update TabItem visibility
|
922 |
+
)
|
923 |
|
924 |
selected_models_display = gr.JSON(label="Final input data", visible=False)
|
925 |
metrics_df = gr.DataFrame(visible=False)
|
926 |
metrics_df_out = gr.DataFrame(visible=False)
|
927 |
|
928 |
submit_models_button.click(
|
929 |
+
fn=qatch_flow_nl_sql,
|
930 |
+
inputs=[],
|
931 |
+
outputs=[evaluation_loading, model_logo, variable, question_display, prediction_display, metrics_df] + list(dataframe_per_model.values())
|
932 |
+
)
|
933 |
+
|
934 |
+
submit_models_button_tqa.click(
|
935 |
+
fn=qatch_flow_nl_sql,
|
936 |
inputs=[],
|
937 |
outputs=[evaluation_loading, model_logo, variable, question_display, prediction_display, metrics_df] + list(dataframe_per_model.values())
|
938 |
)
|
|
|
941 |
fn=lambda: gr.update(value=input_data),
|
942 |
outputs=[selected_models_display]
|
943 |
)
|
944 |
+
submit_models_button_tqa.click(
|
945 |
+
fn=lambda: gr.update(value=input_data),
|
946 |
+
outputs=[selected_models_display]
|
947 |
+
)
|
948 |
|
949 |
# Works for METRICS
|
950 |
metrics_df.change(fn=change_text, inputs=[metrics_df], outputs=[metrics_df_out])
|
|
|
967 |
fn=lambda: gr.update(visible=False),
|
968 |
outputs=[download_metrics]
|
969 |
)
|
970 |
+
submit_models_button_tqa.click(
|
971 |
+
fn=lambda: gr.update(visible=False),
|
972 |
+
outputs=[download_metrics]
|
973 |
+
)
|
974 |
|
975 |
def refresh():
|
976 |
global reset_flag
|
977 |
+
global flag_TQA
|
978 |
reset_flag = True
|
979 |
+
flag_TQA = False
|
980 |
|
981 |
reset_data = gr.Button("Back to upload data section", interactive=True)
|
982 |
|
|
|
1002 |
default_checkbox,
|
1003 |
table_selector,
|
1004 |
*table_outputs,
|
1005 |
+
open_model_selection,
|
1006 |
+
submit_models_button_tqa
|
1007 |
]
|
1008 |
)
|
1009 |
+
|
1010 |
+
|
1011 |
##########################################
|
1012 |
# METRICS VISUALIZATION SECTION #
|
1013 |
##########################################
|
|
|
1022 |
####################################
|
1023 |
|
1024 |
def load_data_csv_es():
|
1025 |
+
|
1026 |
if input_data["input_method"]=="default":
|
1027 |
+
global flag_TQA
|
1028 |
df = pd.read_csv(pnp_path)
|
1029 |
df = df[df['model'].isin(input_data["models"])]
|
1030 |
df = df[df['tbl_name'].isin(input_data["data"]["selected_tables"])]
|
|
|
1035 |
df['model'] = df['model'].replace('llama-70', 'Llama-70B')
|
1036 |
df['model'] = df['model'].replace('llama-8', 'Llama-8B')
|
1037 |
df['test_category'] = df['test_category'].replace('many-to-many-generator', 'MANY-TO-MANY')
|
1038 |
+
if (flag_TQA) : flag_TQA = False #TODO delete after make pred
|
1039 |
return df
|
1040 |
return metrics_df_out
|
1041 |
|
|
|
1078 |
|
1079 |
DB_CATEGORY_COLORS = generate_db_category_colors()
|
1080 |
|
1081 |
+
def normalize_valid_efficency_score(df):
|
1082 |
+
df['valid_efficency_score'] = df['valid_efficency_score'].replace([np.nan, ''], 0)
|
1083 |
+
df['valid_efficency_score'] = df['valid_efficency_score'].astype(int)
|
1084 |
+
min_val = df['valid_efficency_score'].min()
|
1085 |
+
max_val = df['valid_efficency_score'].max()
|
|
|
|
|
1086 |
|
1087 |
+
if min_val == max_val :
|
1088 |
+
# All values are equal, so for avoid division by zero, we set the score to 1/0
|
1089 |
+
if min_val == None:
|
1090 |
+
df['valid_efficency_score'] = 0
|
1091 |
+
else:
|
1092 |
+
df['valid_efficency_score'] = 1.0
|
1093 |
else:
|
1094 |
+
df['valid_efficency_score'] = (
|
1095 |
+
df['valid_efficency_score'] - min_val
|
1096 |
) / (max_val - min_val)
|
1097 |
|
1098 |
return df
|
|
|
1105 |
# BAR CHART FOR AVERAGE METRICS WITH UPDATE FUNCTION
|
1106 |
def plot_metric(df, radio_metric, qatch_selected_metrics, external_selected_metric, group_by, selected_models):
|
1107 |
df = df[df['model'].isin(selected_models)]
|
1108 |
+
df = normalize_valid_efficency_score(df)
|
1109 |
|
1110 |
# Mappatura nomi leggibili -> tecnici
|
1111 |
qatch_selected_internal = [qatch_metrics_dict[label] for label in qatch_selected_metrics]
|
|
|
1222 |
selected_models = [selected_models]
|
1223 |
|
1224 |
df = df[df['model'].isin(selected_models)]
|
1225 |
+
df = normalize_valid_efficency_score(df)
|
1226 |
|
1227 |
# Converti nomi leggibili -> tecnici
|
1228 |
qatch_selected_internal = [qatch_metrics_dict[label] for label in qatch_selected_metrics]
|
|
|
1307 |
)
|
1308 |
|
1309 |
return gr.Plot(fig, visible=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1310 |
|
1311 |
def update_plot_propietary(radio_metric, qatch_selected_metrics, external_selected_metric, selected_models):
|
1312 |
df = load_data_csv_es()
|
|
|
1322 |
df = df[df['db_category'].isin(target_cats)]
|
1323 |
df = df[df['model'].isin(selected_models)]
|
1324 |
|
1325 |
+
df = normalize_valid_efficency_score(df)
|
1326 |
df = calculate_average_metrics(df, qatch_metrics)
|
1327 |
|
1328 |
# Calcola la media per db_category e modello
|
|
|
1443 |
|
1444 |
# RADAR OR BAR CHART BASED ON CATEGORY COUNT
|
1445 |
def plot_radar(df, selected_models, selected_metrics, selected_categories):
|
1446 |
+
if "External" in selected_metrics:
|
1447 |
+
selected_metrics = ["execution_accuracy", "valid_efficency_score"]
|
1448 |
else:
|
1449 |
selected_metrics = ["cell_precision", "cell_recall", "tuple_order", "tuple_cardinality", "tuple_constraint"]
|
1450 |
|
1451 |
# Filtro modelli e normalizzazione
|
1452 |
df = df[df['model'].isin(selected_models)]
|
1453 |
+
df = normalize_valid_efficency_score(df)
|
1454 |
df = calculate_average_metrics(df, selected_metrics)
|
1455 |
|
1456 |
avg_metrics = df.groupby(['model', 'test_category'])['avg_metric'].mean().reset_index()
|
|
|
1607 |
|
1608 |
# RADAR OR BAR CHART FOR SUB-CATEGORIES BASED ON CATEGORY COUNT
|
1609 |
def plot_radar_sub(df, selected_models, selected_metrics, selected_category):
|
1610 |
+
if "External" in selected_metrics:
|
1611 |
+
selected_metrics = ["execution_accuracy", "valid_efficency_score"]
|
1612 |
else:
|
1613 |
selected_metrics = ["cell_precision", "cell_recall", "tuple_order", "tuple_cardinality", "tuple_constraint"]
|
1614 |
|
1615 |
df = df[df['model'].isin(selected_models)]
|
1616 |
+
df = normalize_valid_efficency_score(df)
|
1617 |
df = calculate_average_metrics(df, selected_metrics)
|
1618 |
|
1619 |
if isinstance(selected_category, str):
|
|
|
1776 |
|
1777 |
# RANKING FOR THE 3 WORST RESULTS WITH UPDATE FUNCTION
|
1778 |
def worst_cases_text(df, selected_models, selected_metrics, selected_categories):
|
1779 |
+
global flag_TQA
|
1780 |
if selected_models == "All":
|
1781 |
selected_models = models
|
1782 |
else:
|
|
|
1791 |
df = df[df['test_category'].isin(selected_categories)]
|
1792 |
|
1793 |
if "external" in selected_metrics:
|
1794 |
+
selected_metrics = ["execution_accuracy", "valid_efficency_score"]
|
1795 |
else:
|
1796 |
selected_metrics = ["cell_precision", "cell_recall", "tuple_order", "tuple_cardinality", "tuple_constraint"]
|
1797 |
|
1798 |
+
df = normalize_valid_efficency_score(df)
|
1799 |
df = calculate_average_metrics(df, selected_metrics)
|
|
|
|
|
1800 |
|
1801 |
+
if flag_TQA:
|
1802 |
+
df["target_answer"] = df["target_answer"].apply(
|
1803 |
+
lambda x: " - ".join([",".join(map(str, item)) for item in x]) if isinstance(x, list) else str(x)
|
1804 |
+
)
|
1805 |
+
df["predicted_answer"] = df["predicted_answer"].apply(
|
1806 |
+
lambda x: " - ".join([",".join(map(str, item)) for item in x]) if isinstance(x, list) else str(x)
|
1807 |
+
)
|
1808 |
+
|
1809 |
+
worst_cases_df = df.groupby(['model', 'tbl_name', 'test_category', 'question', 'target_answer', 'predicted_answer', 'answer', 'sql_tag'])['avg_metric'].mean().reset_index()
|
1810 |
+
else:
|
1811 |
+
worst_cases_df = df.groupby(['model', 'tbl_name', 'test_category', 'question', 'query', 'predicted_sql', 'answer', 'sql_tag'])['avg_metric'].mean().reset_index()
|
1812 |
+
|
1813 |
worst_cases_df = worst_cases_df.sort_values(by="avg_metric", ascending=True).reset_index(drop=True)
|
1814 |
|
1815 |
worst_cases_top_3 = worst_cases_df.head(3)
|
|
|
1822 |
medals = ["π₯", "π₯", "π₯"]
|
1823 |
|
1824 |
for i, row in worst_cases_top_3.iterrows():
|
1825 |
+
if flag_TQA:
|
1826 |
+
entry = (
|
1827 |
+
f"<span style='font-size:18px;'><b>{medals[i]} {row['model']} - {row['tbl_name']} - {row['test_category']} - {row['sql_tag']}</b> ({row['avg_metric']})</span> \n"
|
1828 |
+
f"<span style='font-size:16px;'>- <b>Question:</b> {row['question']}</span> \n"
|
1829 |
+
f"<span style='font-size:16px;'>- <b>Original Answer:</b> `{row['target_answer']}`</span> \n"
|
1830 |
+
f"<span style='font-size:16px;'>- <b>Predicted Answer:</b> `{row['predicted_answer']}`</span> \n\n"
|
1831 |
+
)
|
1832 |
+
|
1833 |
+
worst_str.append(entry)
|
1834 |
+
else:
|
1835 |
+
entry = (
|
1836 |
+
f"<span style='font-size:18px;'><b>{medals[i]} {row['model']} - {row['tbl_name']} - {row['test_category']} - {row['sql_tag']}</b> ({row['avg_metric']})</span> \n"
|
1837 |
+
f"<span style='font-size:16px;'>- <b>Question:</b> {row['question']}</span> \n"
|
1838 |
+
f"<span style='font-size:16px;'>- <b>Original Query:</b> `{row['query']}`</span> \n"
|
1839 |
+
f"<span style='font-size:16px;'>- <b>Predicted SQL:</b> `{row['predicted_sql']}`</span> \n\n"
|
1840 |
+
)
|
1841 |
|
1842 |
+
worst_str.append(entry)
|
1843 |
|
1844 |
raw_answer = (
|
1845 |
f"<span style='font-size:18px;'><b>{medals[i]} {row['model']} - {row['tbl_name']} - {row['test_category']} - {row['sql_tag']}</b> ({row['avg_metric']})</span> \n"
|
|
|
1847 |
)
|
1848 |
|
1849 |
answer_str.append(raw_answer)
|
1850 |
+
|
1851 |
return worst_str[0], worst_str[1], worst_str[2], answer_str[0], answer_str[1], answer_str[2]
|
1852 |
|
1853 |
def update_worst_cases_text(selected_models, selected_metrics, selected_categories):
|
|
|
1857 |
# LINE CHART FOR CUMULATIVE TIME WITH UPDATE FUNCTION
|
1858 |
def plot_cumulative_flow(df, selected_models, max_points):
|
1859 |
df = df[df['model'].isin(selected_models)]
|
1860 |
+
df = normalize_valid_efficency_score(df)
|
1861 |
|
1862 |
fig = go.Figure()
|
1863 |
|
|
|
1991 |
|
1992 |
external_metrics_dict = {
|
1993 |
"Execution Accuracy": "execution_accuracy",
|
1994 |
+
"Valid Efficency Score": "valid_efficency_score"
|
1995 |
}
|
1996 |
|
1997 |
+
external_metric = ["execution_accuracy", "valid_efficency_score"]
|
1998 |
last_valid_external_metric_selection = external_metric.copy()
|
1999 |
def enforce_external_metric_selection(selected):
|
2000 |
global last_valid_external_metric_selection
|
|
|
2041 |
|
2042 |
all_model_as_dic = {cat: [f"{cat}"] for cat in models}
|
2043 |
all_model_as_dic["All"] = models
|
|
|
|
|
|
|
|
|
2044 |
|
2045 |
###########################
|
2046 |
# VISUALIZATION SECTION #
|
|
|
2079 |
<span
|
2080 |
title="External metric info:
|
2081 |
Execution Accuracy: Checks if the predicted query returns exactly the same result as the ground truth query when executed. It is a binary metric: 1 if the output matches, 0 otherwise.
|
2082 |
+
Valid Efficency Score: Evaluates the efficency of a query by combining execution time and correctness. It rewards queries that are both accurate and fast."
|
2083 |
style="margin-left: 6px; cursor: help; color: #00bfff; font-size: 16px; white-space: pre-line;"
|
2084 |
>External metric info βΉοΈ</span>
|
2085 |
</div>
|
|
|
2354 |
reset_data.click(open_accordion, inputs=gr.State("reset"), outputs=[upload_acc, select_table_acc, select_model_acc, qatch_acc, metrics_acc, default_checkbox, file_input])
|
2355 |
reset_data.click(fn=lambda: gr.update(visible=False), outputs=[download_metrics])
|
2356 |
reset_data.click(fn=enable_disable, inputs=[gr.State(True)], outputs=[*model_checkboxes, submit_models_button, preview_output, submit_button, file_input, default_checkbox, table_selector, *table_outputs, open_model_selection])
|
2357 |
+
|
2358 |
|
2359 |
interface.launch(share = True)
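The TQA path added to app.py hinges on a module-level flag_TQA that the new submit button flips before the shared prediction flow runs. A minimal sketch of that switch, with the two prompts shortened to stand-ins and the surrounding Gradio state assumed rather than copied from the app:

```python
# Sketch of the task switch in app.py; names mirror the diff, prompts are shortened stand-ins.
prompt_default = "Translate the question into SQL.\nQuestion\n{question}\nDatabase Schema\n{db_schema}\n"
prompt_default_tqa = "Answer the question from the database in <answer> </answer>.\nQuestion\n{question}\nDatabase Schema\n{db_schema}\n"

flag_TQA = False  # submit_models_button_tqa.click(change_flag) sets this to True


def resolve_task_and_prompt(current_prompt: str) -> tuple[str, str]:
    """Return the task code passed to ModelPrediction.make_prediction and the prompt to use."""
    if flag_TQA:
        # If the user kept the default NL2SQL prompt, the app swaps in the TQA prompt.
        prompt = prompt_default_tqa if current_prompt == prompt_default else current_prompt
        return "QA", prompt       # table question answering
    return "SP", current_prompt   # semantic parsing (NL2SQL)


task, prompt = resolve_task_and_prompt(prompt_default)
print(task, prompt.format(question="How many rows?", db_schema="CREATE TABLE t(id);"))
```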
|
concatenated_output.csv
CHANGED
@@ -1,4 +1,4 @@
|
|
1 |
-
cell_precision,sql_tag,tuple_cardinality,answer,predicted_sql,db_category,tuple_constraint,VES,number_question,
|
2 |
1.0,DISTINCT-SINGLE,1.0,"```sql
|
3 |
SELECT DISTINCT WAREHOUSE_LOAD_DATE
|
4 |
FROM FAC_BUILDING_ADDRESS;
|
|
|
1 |
+
cell_precision,sql_tag,tuple_cardinality,answer,predicted_sql,db_category,tuple_constraint,VES,number_question,valid_efficency_score,tbl_name,tuple_order,time,price,question,model,cell_recall,db_path,execution_accuracy,test_category,query
|
2 |
1.0,DISTINCT-SINGLE,1.0,"```sql
|
3 |
SELECT DISTINCT WAREHOUSE_LOAD_DATE
|
4 |
FROM FAC_BUILDING_ADDRESS;
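The regenerated CSV now carries the valid_efficency_score column directly, while app.py still tolerates older result frames that only have VES and copies it over before plotting. A rough illustration of that fallback, with a made-up one-row frame:

```python
import pandas as pd

# Hypothetical metrics frame from an older run: only the 'VES' column is present.
metrics = pd.DataFrame({"model": ["gpt-4o-mini"], "VES": [0.82]})

# Same fallback the diff applies to metrics_conc before visualization.
if "valid_efficency_score" not in metrics.columns and "VES" in metrics.columns:
    metrics["valid_efficency_score"] = metrics["VES"]

print(metrics)
```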
|
utilities.py
CHANGED
@@ -6,6 +6,11 @@ import sqlite3
|
|
6 |
import gradio as gr
|
7 |
import os
|
8 |
from qatch.connectors.sqlite_connector import SqliteConnector
|
|
|
|
|
|
|
|
|
|
|
9 |
def extract_tables(file_path):
|
10 |
conn = sqlite3.connect(file_path)
|
11 |
cursor = conn.cursor()
|
@@ -26,7 +31,7 @@ def extract_dataframes(file_path):
|
|
26 |
return dfs
|
27 |
|
28 |
def carica_sqlite(file_path, db_id):
|
29 |
-
data_output = {'data_frames': extract_dataframes(file_path),'db':SqliteConnector(relative_db_path=file_path, db_name=db_id)}
|
30 |
return data_output
|
31 |
|
32 |
# Funzione per leggere un file CSV
|
@@ -113,3 +118,94 @@ def generate_some_samples(file_path, tbl_name):
|
|
113 |
def load_tables_dict_from_pkl(file_path):
|
114 |
with open(file_path, 'rb') as f:
|
115 |
return pickle.load(f)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
6 |
import gradio as gr
|
7 |
import os
|
8 |
from qatch.connectors.sqlite_connector import SqliteConnector
|
9 |
+
from qatch.evaluate_dataset.metrics_evaluators import CellPrecision, CellRecall, ExecutionAccuracy, TupleCardinality, TupleConstraint, TupleOrder, ValidEfficiencyScore
|
10 |
+
import qatch.evaluate_dataset.orchestrator_evaluator as eva
|
11 |
+
#import tiktoken
|
12 |
+
from transformers import AutoTokenizer
|
13 |
+
|
14 |
def extract_tables(file_path):
|
15 |
conn = sqlite3.connect(file_path)
|
16 |
cursor = conn.cursor()
|
|
|
31 |
return dfs
|
32 |
|
33 |
def carica_sqlite(file_path, db_id):
|
34 |
+
data_output = {'data_frames': extract_dataframes(file_path),'db': SqliteConnector(relative_db_path=file_path, db_name=db_id)}
|
35 |
return data_output
|
36 |
|
37 |
# Funzione per leggere un file CSV
|
|
|
118 |
def load_tables_dict_from_pkl(file_path):
|
119 |
with open(file_path, 'rb') as f:
|
120 |
return pickle.load(f)
|
121 |
+
|
122 |
+
def extract_tables_dict(pnp_path):
|
123 |
+
return load_tables_dict_from_pkl('tables_dict_beaver.pkl')
|
124 |
+
tables_dict = {}
|
125 |
+
with open(pnp_path, mode='r', encoding='utf-8') as file:
|
126 |
+
reader = csv.DictReader(file)
|
127 |
+
tbl_db_pairs = set() # Use a set to avoid duplicates
|
128 |
+
for row in reader:
|
129 |
+
tbl_name = row.get("tbl_name")
|
130 |
+
db_path = row.get("db_path")
|
131 |
+
if tbl_name and db_path:
|
132 |
+
tbl_db_pairs.add((tbl_name, db_path)) # Add the pair to the set
|
133 |
+
for tbl_name, db_path in list(tbl_db_pairs):
|
134 |
+
if tbl_name and db_path:
|
135 |
+
connector = sqlite3.connect(db_path)
|
136 |
+
query = f"SELECT * FROM {tbl_name} LIMIT 5"
|
137 |
+
try:
|
138 |
+
df = pd.read_sql_query(query, connector)
|
139 |
+
tables_dict[tbl_name] = df
|
140 |
+
except Exception as e:
|
141 |
+
tables_dict[tbl_name] = pd.DataFrame({"Error": [str(e)]}) # DataFrame con messaggio di errore
|
142 |
+
#with open('tables_dict_beaver.pkl', 'wb') as f:
|
143 |
+
# pickle.dump(tables_dict, f)
|
144 |
+
return tables_dict
|
145 |
+
|
146 |
+
|
147 |
+
def extract_answer(df):
|
148 |
+
if "query" not in df.columns or "db_path" not in df.columns:
|
149 |
+
raise ValueError("The DataFrame must contain 'query' and 'data_path' columns.")
|
150 |
+
|
151 |
+
answers = []
|
152 |
+
for _, row in df.iterrows():
|
153 |
+
query = row["query"]
|
154 |
+
db_path = row["db_path"]
|
155 |
+
try:
|
156 |
+
conn = SqliteConnector(relative_db_path = db_path , db_name= "db")
|
157 |
+
answer = eva._utils_run_query_if_str(query, conn)
|
158 |
+
answers.append(answer)
|
159 |
+
except Exception as e:
|
160 |
+
answers.append(f"Error: {e}")
|
161 |
+
|
162 |
+
df["target_answer"] = answers
|
163 |
+
return df
|
164 |
+
|
165 |
+
evaluator = {
|
166 |
+
"cell_precision": CellPrecision(),
|
167 |
+
"cell_recall": CellRecall(),
|
168 |
+
"tuple_cardinality": TupleCardinality(),
|
169 |
+
"tuple_order": TupleOrder(),
|
170 |
+
"tuple_constraint": TupleConstraint(),
|
171 |
+
"execution_accuracy": ExecutionAccuracy(),
|
172 |
+
"valid_efficency_score": ValidEfficiencyScore()
|
173 |
+
}
|
174 |
+
|
175 |
+
def evaluate_answer(df):
|
176 |
+
for metric_name, metric in evaluator.items():
|
177 |
+
results = []
|
178 |
+
for _, row in df.iterrows():
|
179 |
+
target = row["target_answer"]
|
180 |
+
predicted = row["predicted_answer"]
|
181 |
+
try:
|
182 |
+
result = metric.run_metric(target = target, prediction = predicted)
|
183 |
+
except Exception as e:
|
184 |
+
result = None
|
185 |
+
results.append(result)
|
186 |
+
df[metric_name] = results
|
187 |
+
return df
|
188 |
+
|
189 |
+
models = [
|
190 |
+
"gpt-4o-mini",
|
191 |
+
"deepseek-ai/DeepSeek-R1-Distill-Llama-70B",
|
192 |
+
]
|
193 |
+
|
194 |
+
def crop_entries_per_token(entries_list, model, prompt: str | None = None):
|
195 |
+
#open_ai_models = ["gpt-3.5", "gpt-4o-mini"]
|
196 |
+
dimension = 2048
|
197 |
+
#enties_string = [", ".join(map(str, entry)) for entry in entries_list]
|
198 |
+
if prompt:
|
199 |
+
entries_string = prompt.join(entries_list)
|
200 |
+
else:
|
201 |
+
entries_string = " ".join(entries_list)
|
202 |
+
#if model in ["deepseek-ai/DeepSeek-R1-Distill-Llama-70B" ,"gpt-4o-mini" ] :
|
203 |
+
#tokenizer = tiktoken.encoding_for_model("gpt-4o-mini")
|
204 |
+
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path = "deepseek-ai/DeepSeek-R1-Distill-Llama-70B")
|
205 |
+
|
206 |
+
tokens = tokenizer.encode(entries_string)
|
207 |
+
number_of_tokens = len(tokens)
|
208 |
+
if number_of_tokens > dimension and len(entries_list) > 4:
|
209 |
+
entries_list = entries_list[:round(len(entries_list)/2)]
|
210 |
+
entries_list = crop_entries_per_token(entries_list, model)
|
211 |
+
return entries_list
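A rough usage sketch of the two TQA helpers added above. The database path and the prediction are made up, and the QATCH metric classes are assumed to accept plain Python lists as the diff does:

```python
import pandas as pd
import utilities as us

# Hypothetical predictions frame: 'predicted_answer' is what the model returned for the TQA task.
preds = pd.DataFrame({
    "query": ["SELECT name FROM students"],
    "db_path": ["./students.sqlite"],
    "predicted_answer": [[["Alice"], ["Bob"]]],
})

preds = us.extract_answer(preds)    # runs each query and stores the result in 'target_answer'
scored = us.evaluate_answer(preds)  # adds one column per metric (cell_precision, ..., valid_efficency_score)
print(scored.filter(["cell_precision", "cell_recall", "execution_accuracy"]))
```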
|
utils_get_db_tables_info.py
CHANGED
@@ -1,10 +1,10 @@
|
|
1 |
import os
|
2 |
import sqlite3
|
3 |
import re
|
4 |
-
|
5 |
|
6 |
def utils_extract_db_schema_as_string(
|
7 |
-
db_id, base_path, normalize=False, sql: str | None = None, get_insert_into: bool = False
|
8 |
):
|
9 |
"""
|
10 |
Extracts the full schema of an SQLite database into a single string.
|
@@ -19,7 +19,7 @@ def utils_extract_db_schema_as_string(
|
|
19 |
cursor = connection.cursor()
|
20 |
|
21 |
# Get the schema entries based on the provided SQL query
|
22 |
-
schema_entries = _get_schema_entries(cursor, sql, get_insert_into)
|
23 |
|
24 |
# Combine all schema definitions into a single string
|
25 |
schema_string = _combine_schema_entries(schema_entries, normalize)
|
@@ -28,7 +28,7 @@ def utils_extract_db_schema_as_string(
|
|
28 |
|
29 |
|
30 |
|
31 |
-
def _get_schema_entries(cursor, sql=None, get_insert_into=False):
|
32 |
"""
|
33 |
Retrieves schema entries and optionally data entries from the SQLite database.
|
34 |
|
@@ -62,11 +62,17 @@ def _get_schema_entries(cursor, sql=None, get_insert_into=False):
|
|
62 |
column_names = [description[0] for description in cursor.description]
|
63 |
|
64 |
# Generate INSERT INTO statements for each row
|
65 |
-
|
66 |
-
|
|
|
|
|
|
|
|
|
67 |
values = ', '.join(f"'{str(value)}'" if isinstance(value, str) else str(value) for value in row)
|
68 |
insert_stmt = f"INSERT INTO {table} ({', '.join(column_names)}) VALUES ({values});"
|
69 |
entries.append(insert_stmt)
|
|
|
|
|
70 |
|
71 |
return entries
|
72 |
|
|
|
1 |
import os
|
2 |
import sqlite3
|
3 |
import re
|
4 |
+
import utilities as us
|
5 |
|
6 |
def utils_extract_db_schema_as_string(
|
7 |
+
db_id, base_path, model : str | None = None , normalize=False, sql: str | None = None, get_insert_into: bool = False, prompt : str | None = None
|
8 |
):
|
9 |
"""
|
10 |
Extracts the full schema of an SQLite database into a single string.
|
|
|
19 |
cursor = connection.cursor()
|
20 |
|
21 |
# Get the schema entries based on the provided SQL query
|
22 |
+
schema_entries = _get_schema_entries(cursor, sql, get_insert_into, model, prompt)
|
23 |
|
24 |
# Combine all schema definitions into a single string
|
25 |
schema_string = _combine_schema_entries(schema_entries, normalize)
|
|
|
28 |
|
29 |
|
30 |
|
31 |
+
def _get_schema_entries(cursor, sql=None, get_insert_into=False, model: str | None = None, prompt : str | None = None):
|
32 |
"""
|
33 |
Retrieves schema entries and optionally data entries from the SQLite database.
|
34 |
|
|
|
62 |
column_names = [description[0] for description in cursor.description]
|
63 |
|
64 |
# Generate INSERT INTO statements for each row
|
65 |
+
if model==None :
|
66 |
+
max_len=3
|
67 |
+
else:
|
68 |
+
max_len = len(rows)
|
69 |
+
|
70 |
+
for row in rows[:max_len]:
|
71 |
values = ', '.join(f"'{str(value)}'" if isinstance(value, str) else str(value) for value in row)
|
72 |
insert_stmt = f"INSERT INTO {table} ({', '.join(column_names)}) VALUES ({values});"
|
73 |
entries.append(insert_stmt)
|
74 |
+
|
75 |
+
if model != None : entries = us.crop_entries_per_token(entries, model, prompt)
|
76 |
|
77 |
return entries
|
78 |
|