File size: 4,255 Bytes
890025a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import streamlit as st 
from .hyperparametrs import get_hyperparams_ui 
import pickle
from .train import train_model

#  Model Training Tab
# Models offered for each inferred task type; the first entry is the default.
_MODEL_OPTIONS = {
    "classification": (
        "Random Forest",
        "Logistic Regression",
        "XGBoost",
        "Support Vector Classifier",
        "Decision Tree Classifier",
        "K-Nearest Neighbors Classifier",
        "Gradient Boosting Classifier",
        "AdaBoost Classifier",
        "Gaussian Naive Bayes",
        "Quadratic Discriminant Analysis",
        "Linear Discriminant Analysis",
    ),
    "regression": (
        "Linear Regression",
        "Random Forest Regressor",
        "XGBoost Regressor",
        "Support Vector Regressor",
        "Decision Tree Regressor",
        "K-Nearest Neighbors Regressor",
        "ElasticNet",
        "Gradient Boosting Regressor",
        "AdaBoost Regressor",
        "Bayesian Ridge",
        "Ridge Regression",
        "Lasso Regression",
    ),
}


def _infer_task_type(series, max_classes=10):
    """Guess whether predicting ``series`` is a classification or regression task.

    Non-numeric targets (object, string, category, ...) are treated as
    classification. Numeric targets with at most ``max_classes`` distinct
    values (e.g. 0/1 labels or small integer codes) are also treated as
    classification. Everything else is regression.

    Returns:
        ``"classification"`` or ``"regression"``.
    """
    # Local import: pandas is already needed at runtime (``df`` is a
    # DataFrame), but this module does not import it at top level.
    from pandas.api.types import is_numeric_dtype

    # Comparing ``dtype == "object"`` alone would miss pandas' ``category``
    # and ``string`` dtypes, sending high-cardinality text/categorical targets
    # to the regression models.
    if not is_numeric_dtype(series.dtype) or series.nunique() <= max_classes:
        return "classification"
    return "regression"


def _render_model_download(model):
    """Offer ``model`` as a pickled ``trained_model.pkl`` download button.

    The model is serialized in memory with ``pickle.dumps``, so no temporary
    file is written (or leaked) on each Streamlit rerun. Serialization errors
    are shown in the UI instead of crashing the tab.
    """
    st.markdown("### 📥 Download Trained Model")
    try:
        model_bytes = pickle.dumps(model)
    except Exception as e:  # UI boundary: report any pickling failure to the user
        st.error(f"Error preparing model for download: {str(e)}")
    else:
        # NOTE: unpickling can execute arbitrary code; downloaded models
        # should only be loaded from trusted sources.
        st.download_button(
            label="📥 Download Model",
            data=model_bytes,
            file_name="trained_model.pkl",
            mime="application/octet-stream",
        )


def model_training_tab(df):
    """Render the "Model Training" tab for ``df``.

    The tab lets the user pick a target column and shows the inferred task
    type (classification or regression). The user then chooses a model and its
    hyperparameters, and trains it on demand via ``train_model``. The most
    recently trained model is offered as a pickle download.

    Session state written:
        target_column, selected_model: the user's last choices, reused to
            preselect the widgets on the next rerun.
        trained_model, model_trained: set after a training run completes.

    Args:
        df: The dataset to train on (a pandas DataFrame).
    """
    st.subheader("📌 Model Training")

    # Nothing to select or train on: stop before st.selectbox receives an
    # empty option list and ``df[None]`` raises KeyError.
    if df is None or df.empty:
        st.warning("⚠️ Load a dataset with at least one row and one column to train a model.")
        return

    columns = list(df.columns)

    # Preselect the previous target; fall back to the first column when none
    # was chosen yet or the previous choice is not a column of this df.
    previous_target = st.session_state.get("target_column")
    target_column = st.selectbox(
        "🎯 Select Target Column (Y)",
        columns,
        index=columns.index(previous_target) if previous_target in columns else 0,
        key="target_column_select",
    )
    st.session_state.target_column = target_column

    task_type = _infer_task_type(df[target_column])
    st.write(f"🔍 Detected Task Type: **{task_type.capitalize()}**")

    candidates = _MODEL_OPTIONS[task_type]
    previous_model = st.session_state.get("selected_model")
    # Reset to the default model when nothing was chosen yet or the task type
    # changed and the previous model does not apply to it.
    if previous_model not in candidates:
        previous_model = candidates[0]
    selected_model_name = st.selectbox(
        "🤖 Choose Model",
        candidates,
        index=candidates.index(previous_model),
        key="selected_model_select",
    )
    st.session_state.selected_model = selected_model_name

    st.markdown("### 🔧 Hyperparameters")
    hyperparams = get_hyperparams_ui(selected_model_name)

    # Explicit key so this button's widget ID cannot collide with another
    # "Train Model" button elsewhere in the app.
    if st.button("🚀 Train Model", key="train_model_button_unique"):
        with st.spinner("Training in progress... ⏳"):
            trained_model = train_model(df, target_column, task_type, selected_model_name, hyperparams)
        st.success("✅ Model trained successfully!")
        st.session_state.trained_model = trained_model
        st.session_state.model_trained = True
        # Note: test_results_calculated is already reset in train_model function

    if "trained_model" in st.session_state:
        _render_model_download(st.session_state.trained_model)