AutoML / src /training /model_training.py
akash
all files
890025a
import streamlit as st
from .hyperparametrs import get_hyperparams_ui
import pickle
from .train import train_model
# Model Training Tab
def model_training_tab(df):
# Ensure we have session state for model training
if "target_column" not in st.session_state:
st.session_state.target_column = df.columns[0] if not df.empty else None
if "selected_model" not in st.session_state:
st.session_state.selected_model = None
st.subheader("πŸ“Œ Model Training")
# Use session state to maintain selection across reruns
target_column = st.selectbox(
"🎯 Select Target Column (Y)",
df.columns,
index=list(df.columns).index(st.session_state.target_column) if st.session_state.target_column in df.columns else 0,
key="target_column_select"
)
# Update session state after selection
st.session_state.target_column = target_column
# Infer task type automatically
task_type = "classification" if df[target_column].dtype == "object" or df[target_column].nunique() <= 10 else "regression"
st.write(f"πŸ” Detected Task Type: **{task_type.capitalize()}**")
model_options = {
"classification": ["Random Forest", "Logistic Regression", "XGBoost" , "Support Vector Classifier", "Decision Tree Classifier", "K-Nearest Neighbors Classifier", "Gradient Boosting Classifier", "AdaBoost Classifier", "Gaussian Naive Bayes", "Quadratic Discriminant Analysis", "Linear Discriminant Analysis"],
"regression": ["Linear Regression", "Random Forest Regressor", "XGBoost Regressor" , "Support Vector Regressor", "Decision Tree Regressor", "K-Nearest Neighbors Regressor", "ElasticNet", "Gradient Boosting Regressor", "AdaBoost Regressor", "Bayesian Ridge" , "Ridge Regression", "Lasso Regression"],
}
# Initialize selected model if not already set or if task type changed
if st.session_state.selected_model not in model_options[task_type]:
st.session_state.selected_model = model_options[task_type][0]
# Use session state to maintain selection across reruns
selected_model_name = st.selectbox(
"πŸ€– Choose Model",
model_options[task_type],
index=model_options[task_type].index(st.session_state.selected_model),
key="selected_model_select"
)
# Update session state after selection
st.session_state.selected_model = selected_model_name
st.markdown("### πŸ”§ Hyperparameters")
hyperparams = get_hyperparams_ui(selected_model_name)
# Use a unique key for the button to avoid conflicts
if st.button("πŸš€ Train Model", key="train_model_button_unique"):
with st.spinner("Training in progress... ⏳"):
trained_model = train_model(df, target_column, task_type, selected_model_name, hyperparams)
st.success("βœ… Model trained successfully!")
st.session_state.trained_model = trained_model
st.session_state.model_trained = True
# Note: test_results_calculated is already reset in train_model function
if "trained_model" in st.session_state:
st.markdown("### πŸ“₯ Download Trained Model")
# Use a safer approach for file operations with proper cleanup
try:
# Use a temporary file that will be automatically cleaned up
import tempfile
with tempfile.NamedTemporaryFile(delete=False, suffix='.pkl') as temp_file:
pickle.dump(st.session_state.trained_model, temp_file)
temp_file_path = temp_file.name
# Read the file for download
with open(temp_file_path, "rb") as f:
st.download_button(
label="πŸ“₯ Download Model",
data=f,
file_name="trained_model.pkl",
mime="application/octet-stream",
)
# Clean up the temporary file
import os
try:
os.unlink(temp_file_path)
except:
pass # Silently handle deletion errors
except Exception as e:
st.error(f"Error preparing model for download: {str(e)}")