import os
import pickle
import tempfile

import streamlit as st

from .hyperparametrs import get_hyperparams_ui
from .train import train_model

# Model Training Tab
def model_training_tab(df):
    # Ensure we have session state for model training
    if "target_column" not in st.session_state:
        st.session_state.target_column = df.columns[0] if not df.empty else None
    if "selected_model" not in st.session_state:
        st.session_state.selected_model = None

    st.subheader("Model Training")
    # Use session state to maintain selection across reruns
    target_column = st.selectbox(
        "Select Target Column (Y)",
        df.columns,
        index=list(df.columns).index(st.session_state.target_column)
        if st.session_state.target_column in df.columns
        else 0,
        key="target_column_select",
    )
    # Update session state after selection
    st.session_state.target_column = target_column
    # Infer task type automatically
    task_type = (
        "classification"
        if df[target_column].dtype == "object" or df[target_column].nunique() <= 10
        else "regression"
    )
    st.write(f"Detected Task Type: **{task_type.capitalize()}**")
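    # Illustrative note (not part of the original code): with this heuristic, an integer
    # label column containing only {0, 1} has nunique() == 2 <= 10 and is treated as
    # classification, while a continuous column with many distinct values (e.g. a price)
    # falls through to regression.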
    model_options = {
        "classification": [
            "Random Forest",
            "Logistic Regression",
            "XGBoost",
            "Support Vector Classifier",
            "Decision Tree Classifier",
            "K-Nearest Neighbors Classifier",
            "Gradient Boosting Classifier",
            "AdaBoost Classifier",
            "Gaussian Naive Bayes",
            "Quadratic Discriminant Analysis",
            "Linear Discriminant Analysis",
        ],
        "regression": [
            "Linear Regression",
            "Random Forest Regressor",
            "XGBoost Regressor",
            "Support Vector Regressor",
            "Decision Tree Regressor",
            "K-Nearest Neighbors Regressor",
            "ElasticNet",
            "Gradient Boosting Regressor",
            "AdaBoost Regressor",
            "Bayesian Ridge",
            "Ridge Regression",
            "Lasso Regression",
        ],
    }
    # Initialize selected model if not already set or if task type changed
    if st.session_state.selected_model not in model_options[task_type]:
        st.session_state.selected_model = model_options[task_type][0]

    # Use session state to maintain selection across reruns
    selected_model_name = st.selectbox(
        "Choose Model",
        model_options[task_type],
        index=model_options[task_type].index(st.session_state.selected_model),
        key="selected_model_select",
    )
    # Update session state after selection
    st.session_state.selected_model = selected_model_name
    st.markdown("### Hyperparameters")
    hyperparams = get_hyperparams_ui(selected_model_name)
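    # Assumed contract (a sketch, not verified against hyperparametrs.py): get_hyperparams_ui
    # is expected to render model-specific widgets and return a plain dict of keyword
    # arguments, e.g. {"n_estimators": 100, "max_depth": 5} for "Random Forest", which is
    # then forwarded unchanged to train_model below.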
    # Use a unique key for the button to avoid conflicts
    if st.button("Train Model", key="train_model_button_unique"):
        with st.spinner("Training in progress..."):
            trained_model = train_model(
                df, target_column, task_type, selected_model_name, hyperparams
            )
        st.success("Model trained successfully!")
        st.session_state.trained_model = trained_model
        st.session_state.model_trained = True
        # Note: test_results_calculated is already reset in train_model function
if "trained_model" in st.session_state:
st.markdown("### π₯ Download Trained Model")
# Use a safer approach for file operations with proper cleanup
try:
# Use a temporary file that will be automatically cleaned up
import tempfile
with tempfile.NamedTemporaryFile(delete=False, suffix='.pkl') as temp_file:
pickle.dump(st.session_state.trained_model, temp_file)
temp_file_path = temp_file.name
# Read the file for download
with open(temp_file_path, "rb") as f:
st.download_button(
label="π₯ Download Model",
data=f,
file_name="trained_model.pkl",
mime="application/octet-stream",
)
# Clean up the temporary file
import os
try:
os.unlink(temp_file_path)
except:
pass # Silently handle deletion errors
except Exception as e:
st.error(f"Error preparing model for download: {str(e)}")