|
import streamlit as st |
|
from .hyperparametrs import get_hyperparams_ui |
|
import pickle |
|
from .train import train_model |
|
|
|
|
|
# Candidate models offered for each detected task type. The display names are
# passed verbatim to ``get_hyperparams_ui`` and ``train_model``, so presumably
# they must match the names those helpers expect.
_MODEL_OPTIONS = {
    "classification": [
        "Random Forest",
        "Logistic Regression",
        "XGBoost",
        "Support Vector Classifier",
        "Decision Tree Classifier",
        "K-Nearest Neighbors Classifier",
        "Gradient Boosting Classifier",
        "AdaBoost Classifier",
        "Gaussian Naive Bayes",
        "Quadratic Discriminant Analysis",
        "Linear Discriminant Analysis",
    ],
    "regression": [
        "Linear Regression",
        "Random Forest Regressor",
        "XGBoost Regressor",
        "Support Vector Regressor",
        "Decision Tree Regressor",
        "K-Nearest Neighbors Regressor",
        "ElasticNet",
        "Gradient Boosting Regressor",
        "AdaBoost Regressor",
        "Bayesian Ridge",
        "Ridge Regression",
        "Lasso Regression",
    ],
}

# Non-text targets with at most this many distinct values are treated as labels.
_MAX_CLASSIFICATION_LABELS = 10


def _detect_task_type(target):
    """Return ``"classification"`` or ``"regression"`` for the target column.

    Text-like targets (object, bytes, unicode, and pandas ``category`` /
    ``string`` dtypes, whose ``dtype.kind`` is ``"O"``) are always
    classification. Any other target is classification only when it has at
    most ``_MAX_CLASSIFICATION_LABELS`` distinct non-null values.

    Args:
        target: the target column (a pandas ``Series``).
    """
    # Comparing ``dtype == "object"`` misses ``category``/``string`` dtypes,
    # which then fell through to "regression" when they had > 10 labels.
    if target.dtype.kind in "OSU":
        return "classification"
    if target.nunique() <= _MAX_CLASSIFICATION_LABELS:
        return "classification"
    return "regression"


def _render_model_download(model):
    """Render a section with a button that downloads *model* as ``trained_model.pkl``."""
    st.markdown("### 📥 Download Trained Model")
    try:
        # Serialise in memory: st.download_button accepts raw bytes, so no
        # temporary file (and no cleanup that could fail or leak) is needed.
        payload = pickle.dumps(model)
    except Exception as exc:  # unpicklable models (e.g. holding lambdas) end up here
        st.error(f"Error preparing model for download: {str(exc)}")
        return
    st.download_button(
        label="📥 Download Model",
        data=payload,
        file_name="trained_model.pkl",
        mime="application/octet-stream",
    )


def model_training_tab(df):
    """Render the "Model Training" tab for the dataset *df*.

    The tab does the following, in order:
    - lets the user pick a target column and shows the auto-detected task type;
    - offers the models for that task type, with their hyperparameter widgets;
    - trains the chosen model when the user presses the button;
    - offers the most recently trained model as a pickle download.

    Selections persist across reruns in ``st.session_state.target_column``
    and ``st.session_state.selected_model``. A successful training run stores
    the model in ``st.session_state.trained_model`` and sets
    ``st.session_state.model_trained = True``.

    Args:
        df: the dataset to train on (used as a pandas ``DataFrame``).
    """
    # Without rows and columns there is nothing to select or train on, and
    # indexing ``df`` with a missing target would raise KeyError below.
    if df is None or df.empty:
        st.warning("The dataset is empty - load data with at least one row and one column to train a model.")
        return

    if "target_column" not in st.session_state:
        st.session_state.target_column = df.columns[0]
    if "selected_model" not in st.session_state:
        st.session_state.selected_model = None

    st.subheader("🚀 Model Training")

    # Pre-select the previously chosen target if it still exists in df.
    columns = list(df.columns)
    previous_target = st.session_state.target_column
    target_column = st.selectbox(
        "🎯 Select Target Column (Y)",
        columns,
        index=columns.index(previous_target) if previous_target in columns else 0,
        key="target_column_select",
    )
    st.session_state.target_column = target_column

    task_type = _detect_task_type(df[target_column])
    st.write(f"🔍 Detected Task Type: **{task_type.capitalize()}**")

    options = _MODEL_OPTIONS[task_type]
    # The stored choice may be unset or belong to the other task type.
    if st.session_state.selected_model not in options:
        st.session_state.selected_model = options[0]

    selected_model_name = st.selectbox(
        "🤖 Choose Model",
        options,
        index=options.index(st.session_state.selected_model),
        key="selected_model_select",
    )
    st.session_state.selected_model = selected_model_name

    st.markdown("### 🔧 Hyperparameters")
    hyperparams = get_hyperparams_ui(selected_model_name)

    if st.button("🚀 Train Model", key="train_model_button_unique"):
        try:
            with st.spinner("Training in progress... ⏳"):
                trained_model = train_model(df, target_column, task_type, selected_model_name, hyperparams)
        except Exception as exc:
            # Report the failure in the UI instead of aborting the script, so
            # the rest of the tab (including an earlier model's download) still renders.
            st.error(f"Training failed: {exc}")
        else:
            st.session_state.trained_model = trained_model
            st.session_state.model_trained = True
            st.success("✅ Model trained successfully!")

    if "trained_model" in st.session_state:
        _render_model_download(st.session_state.trained_model)
|
|