"""Streamlit dashboard that replays an hourly electricity-consumption forecast.

The user picks a model (LightGBM or a Transformer), a date range and a
playback speed.  The script then steps through the test period one hour at a
time, plotting each new prediction next to the previously observed value and
showing interim error metrics (MAPE, MAE, max error).

NOTE(review): several ``st.markdown(..., unsafe_allow_html=True)`` calls below
contain only whitespace/newlines.  Their HTML/CSS payload appears to be missing
from the source this file was restored from -- confirm against the original.
"""
import sys
import os
import pickle
import time
import warnings

import streamlit as st
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.dates as mdates
import torch

from config_streamlit import (MODEL_PATH_LIGHTGBM, DATA_PATH, TRAIN_RATIO, PLOT_COLOR)
from lightgbm_model.scripts.config_lightgbm import FEATURES
from transformer_model.scripts.utils.informer_dataset_class import InformerDataset
from transformer_model.scripts.training.load_basis_model import load_moment_model
from transformer_model.scripts.config_transformer import CHECKPOINT_DIR, FORECAST_HORIZON, SEQ_LEN
from sklearn.preprocessing import StandardScaler
from huggingface_hub import hf_hub_download

# Labels of the model selectbox; compared against ``model_choice`` below.
LIGHTGBM_LABEL = "LightGBM"
TRANSFORMER_LABEL = "Transformer Model (moments)"

# Wall-clock time used to align the simulation start with the start of a day.
MIDNIGHT = pd.Timestamp("00:00:00").time()

# ============================== Layout ==============================

# Streamlit & warnings config
warnings.filterwarnings("ignore", category=FutureWarning)
st.set_page_config(page_title="Electricity Consumption Forecast", layout="wide")

# CSS part
# NOTE(review): the stylesheet body seems to be missing here -- confirm.
st.markdown(f""" """, unsafe_allow_html=True)

st.title("Electricity Consumption Forecast: Hourly Simulation")
st.write("Welcome to the simulation interface!")

st.info(
    "**Simulation Overview:**\n\n"
    "This dashboard provides an hourly electricity consumption forecast using two different models: "
    "**LightGBM** and a **Transformer (moment-based)**. Both models generate a fresh prediction at every time step "
    "(i.e., every simulated hour).\n\n"
    "Note: Since this app runs on a limited CPU on Hugging Face Spaces, the Transformer model may respond slower "
    "compared to local execution. On a standard local CPU, performance is significantly better."
)


# ============================== Session State Init ==============================

def _session_defaults() -> dict:
    """Return a fresh dict of the default simulation state.

    A new dict (with new list objects) is built on every call so that the
    lists are never shared between sessions or between resets.
    """
    return {
        "is_running": False,
        "start_index": 0,
        "true_vals": [],
        "pred_vals": [],
        "true_timestamps": [],
        "pred_timestamps": [],
        "last_fig": None,
        "valid_pos": 0,
        "first_plot_shown": False,
    }


def init_session_state() -> None:
    """Populate ``st.session_state`` with defaults for any key not yet present."""
    for key, value in _session_defaults().items():
        if key not in st.session_state:
            st.session_state[key] = value


def _reset_simulation_state(*, stop: bool) -> None:
    """Reset all simulation state to its defaults.

    Args:
        stop: When True, ``is_running`` is reset to False as well (manual
            reset).  When False, ``is_running`` is left untouched so that a
            running simulation restarts from scratch (auto-reset).
    """
    for key, value in _session_defaults().items():
        if key == "is_running" and not stop:
            continue
        st.session_state[key] = value


init_session_state()

# ============================== Loaders ==============================

CSV_PATH_HF = hf_hub_download(
    repo_id="dlaj/energy-forecasting-files",
    filename="data/processed/energy_consumption_aggregated_cleaned.csv",
    repo_type="dataset"
)


@st.cache_resource
def load_transformer_model_and_dataset():
    """Load the Transformer model and the test dataset (cached per process).

    Returns:
        Tuple ``(model, test_dataset, device)``: the model in eval mode on
        ``device``, and the test dataset whose scaler has been replaced by the
        scaler of the training dataset.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = load_moment_model()
    checkpoint_path = hf_hub_download(
        repo_id="dlaj/energy-forecasting-files",
        filename="transformer_model/model_final.pth",
        repo_type="dataset"
    )
    model.load_state_dict(torch.load(checkpoint_path, map_location=device))
    model.to(device)
    model.eval()

    # Datasets
    train_dataset = InformerDataset(data_split="train", forecast_horizon=FORECAST_HORIZON,
                                    random_seed=13, csv_path=CSV_PATH_HF)
    test_dataset = InformerDataset(data_split="test", forecast_horizon=FORECAST_HORIZON,
                                   random_seed=13, csv_path=CSV_PATH_HF)
    # Use the training scaler for the test split so both are scaled identically.
    test_dataset.scaler = train_dataset.scaler

    return model, test_dataset, device


@st.cache_data
def _transformer_test_start_index() -> int:
    """Return the CSV row index that corresponds to test-dataset index 0.

    Computed as ``len(train dataset) + SEQ_LEN``.  Cached because building the
    training dataset on every Streamlit rerun is needlessly expensive.
    """
    train_len = len(InformerDataset(data_split="train", forecast_horizon=FORECAST_HORIZON,
                                    csv_path=CSV_PATH_HF))
    return train_len + SEQ_LEN


@st.cache_data
def load_data():
    """Read the consumption CSV with the ``date`` column parsed as datetimes."""
    df = pd.read_csv(CSV_PATH_HF, parse_dates=["date"])
    return df


# Load lightgbm model
# ``cache_resource`` keeps a single model object alive instead of copying it
# through serialization on every call, as ``cache_data`` would.
@st.cache_resource
def load_lightgbm_model():
    """Unpickle and return the LightGBM model stored at ``MODEL_PATH_LIGHTGBM``."""
    # SECURITY: pickle executes arbitrary code on load; MODEL_PATH_LIGHTGBM
    # must only ever point to a trusted, project-controlled file.
    with open(MODEL_PATH_LIGHTGBM, "rb") as f:
        return pickle.load(f)


# ============================== Utility Functions ==============================

def _progress_fraction(done: int, total: int) -> float:
    """Return ``done / total`` clamped to [0.0, 1.0]; 0.0 if ``total`` <= 0."""
    if total <= 0:
        return 0.0
    return min(max(done / total, 0.0), 1.0)


def predict_transformer_step(model, dataset, idx, device):
    """Performs a single prediction step with the transformer model.

    Args:
        model: Callable accepting ``x_enc`` and ``input_mask`` keyword tensors
            and returning an object with a ``forecast`` tensor.
        dataset: Indexable dataset yielding ``(timeseries, _, input_mask)`` and
            exposing ``n_channels`` and ``scaler``.
        idx: Index of the dataset sample to predict from.
        device: Torch device the input tensors are moved to.

    Returns:
        The first forecast step of channel 0, rescaled via
        ``dataset.scaler.inverse_transform``, as a Python float.
    """
    timeseries, _, input_mask = dataset[idx]

    # Add a batch dimension of size 1 before feeding the model.
    timeseries = torch.tensor(timeseries, dtype=torch.float32).unsqueeze(0).to(device)
    input_mask = torch.tensor(input_mask, dtype=torch.bool).unsqueeze(0).to(device)

    with torch.no_grad():
        output = model(x_enc=timeseries, input_mask=input_mask)

    pred = output.forecast[:, 0, :].cpu().numpy().flatten()

    # Rescale back to original units: the scaler expects all channels, so the
    # forecast is placed in column 0 of an otherwise zero-filled array.
    dummy = np.zeros((len(pred), dataset.n_channels))
    dummy[:, 0] = pred
    pred_original = dataset.scaler.inverse_transform(dummy)[:, 0]

    return float(pred_original[0])


def init_simulation_layout():
    """Create the placeholders used by the simulation view.

    Returns:
        Tuple ``(plot_title, plot_container, x_axis_label, info_container)`` of
        empty Streamlit placeholders: three in the wide left column, one in the
        narrow right column.
    """
    col1, spacer, col2 = st.columns([3, 0.2, 1])
    plot_title = col1.empty()
    plot_container = col1.empty()
    x_axis_label = col1.empty()
    info_container = col2.empty()
    return plot_title, plot_container, x_axis_label, info_container


def create_prediction_plot(pred_timestamps, pred_vals, true_timestamps, true_vals,
                           window_hours, y_min=None, y_max=None):
    """Generates the matplotlib figure for plotting prediction vs. actual.

    Only the last ``window_hours`` points of each series are drawn.  The
    figure is also stored in ``st.session_state.last_fig`` so that it can be
    shown again while the simulation is paused.

    Args:
        pred_timestamps: Timestamps of the predictions.
        pred_vals: Predicted values.
        true_timestamps: Timestamps of the observed values.
        true_vals: Observed values; the "Actual" line is skipped when empty.
        window_hours: Number of trailing points to display.
        y_min: Lower y-axis limit (``None`` lets matplotlib choose).
        y_max: Upper y-axis limit (``None`` lets matplotlib choose).

    Returns:
        The created matplotlib figure.
    """
    fig, ax = plt.subplots(figsize=(8, 5), constrained_layout=True, facecolor=PLOT_COLOR)
    ax.set_facecolor(PLOT_COLOR)

    ax.plot(pred_timestamps[-window_hours:], pred_vals[-window_hours:],
            label="Prediction", color="#EF233C", linestyle="--")
    if true_vals:
        ax.plot(true_timestamps[-window_hours:], true_vals[-window_hours:],
                label="Actual", color="#0077B6")

    ax.set_ylabel("Consumption (MW)", fontsize=8)
    ax.legend(
        fontsize=8,
        loc="upper left",
        bbox_to_anchor=(0, 0.95)
    )
    ax.yaxis.grid(True, linestyle=':', linewidth=0.5, alpha=0.7)
    ax.set_ylim(y_min, y_max)
    ax.xaxis.set_major_locator(mdates.DayLocator(interval=1))
    ax.xaxis.set_major_formatter(mdates.DateFormatter("%m-%d"))
    ax.tick_params(axis="x", labelrotation=0, labelsize=5)
    ax.tick_params(axis="y", labelsize=5)

    for spine in ax.spines.values():
        spine.set_visible(False)

    st.session_state.last_fig = fig
    return fig


def render_simulation_view(timestamp, prediction, actual, progress, fig, paused=False):
    """Displays the simulation plot and metrics in the UI.

    Relies on the module-level placeholders ``plot_title``, ``plot_container``
    and ``info_container`` created by ``init_simulation_layout()``; these must
    be assigned before this function is called.

    Args:
        timestamp: Time of the current prediction (shown as text).
        prediction: Current predicted value in MW, or ``None``.
        actual: Most recent observed value in MW, or ``None``.
        progress: Value passed to ``st.progress``.
        fig: Matplotlib figure to display.
        paused: When True, the title is suffixed with "(Paused)".
    """
    title = "Actual vs. Prediction (Paused)" if paused else "Actual vs. Prediction"
    # NOTE(review): the HTML wrapper around the title seems to be missing.
    plot_title.markdown(
        f"\n"
        f"{title}\n",
        unsafe_allow_html=True
    )
    plot_container.pyplot(fig)

    with info_container.container():
        st.markdown(
            f"Time: {timestamp}",
            unsafe_allow_html=True
        )
        st.metric("Prediction", f"{prediction:,.0f} MW" if prediction is not None else "–")
        st.metric("Actual", f"{actual:,.0f} MW" if actual is not None else "–")
        st.caption("Simulation Progress")
        st.progress(progress)

        if len(st.session_state.true_vals) > 1:
            true_arr = np.array(st.session_state.true_vals)
            # The newest prediction has no observed counterpart yet, so it is
            # excluded from the error metrics.
            pred_arr = np.array(st.session_state.pred_vals[:-1])

            min_len = min(len(true_arr), len(pred_arr))
            if min_len >= 1:
                errors = np.abs(true_arr[:min_len] - pred_arr[:min_len])
                # Zero actuals are replaced by a tiny value to avoid dividing by zero.
                mape = np.mean(errors / np.where(true_arr[:min_len] == 0, 1e-10, true_arr[:min_len])) * 100
                mae = np.mean(errors)
                max_error = np.max(errors)

                st.divider()
                st.markdown(
                    f"Interim Metrics",
                    unsafe_allow_html=True
                )
                st.metric("MAPE (so far)", f"{mape:.2f} %")
                st.metric("MAE (so far)", f"{mae:,.0f} MW")
                st.metric("Max Error", f"{max_error:,.0f} MW")


# ============================== Data Preparation ==============================

df_full = load_data()

# Split Train/Test
train_size = int(len(df_full) * TRAIN_RATIO)
test_df_raw = df_full.iloc[train_size:].reset_index(drop=True)

# Start at the first row whose time of day is 00:00
first_full_day_index = test_df_raw[test_df_raw["date"].dt.time == MIDNIGHT].index[0]
test_df_full = test_df_raw.iloc[first_full_day_index:].reset_index(drop=True)

# Bounds for the date pickers
min_date = test_df_full["date"].min().date()
max_date = test_df_full["date"].max().date()

# ============================== UI Controls ==============================

with st.sidebar:
    st.header("⚙️ Simulation Settings")
    st.subheader("General Settings")
    model_choice = st.selectbox("Choose prediction model", [LIGHTGBM_LABEL, TRANSFORMER_LABEL])
    if model_choice == TRANSFORMER_LABEL:
        st.caption("⚠️ Note: Transformer model runs slower without GPU. (Use Speed = 10)")
    window_days = st.selectbox("Display window (days)", options=[3, 5, 7], index=0)
    window_hours = window_days * 24
    speed = st.slider("Speed", 1, 10, 5)

    st.subheader("Date Range")
    start_date = st.date_input("Start Date", value=min_date, min_value=min_date, max_value=max_date)
    end_date = st.date_input("End Date", value=max_date, min_value=min_date, max_value=max_date)

# ============================== Data Preparation (filtered) ==============================

# Final filtered date window
test_df_filtered = test_df_full[
    (test_df_full["date"].dt.date >= start_date) &
    (test_df_full["date"].dt.date <= end_date)
].reset_index(drop=True)

# For progress bar
total_steps_ui = len(test_df_filtered)

if total_steps_ui == 0:
    st.warning("No data in the selected date range. Please choose an end date on or after the start date.")

# Fixed y-axis limits for the plot; identical for every simulation step, so
# they are computed once here instead of inside the loops.
plot_y_min = test_df_filtered["consumption_MW"].min() - 2000
plot_y_max = test_df_filtered["consumption_MW"].max() + 2000

# ============================== Buttons ==============================

st.markdown("### Start Simulation")
col1, col2, col3 = st.columns([1, 1, 4])
with col1:
    play_pause_text = "▶️ Start" if not st.session_state.is_running else "⏸️ Pause"
    if st.button(play_pause_text, use_container_width=True):
        st.session_state.is_running = not st.session_state.is_running
        st.rerun()
with col2:
    reset_button = st.button("🔄 Reset", use_container_width=True)

# Reset logic
if reset_button:
    _reset_simulation_state(stop=True)
    st.rerun()

# Auto-reset on critical parameter change while running
selection_changed = (
    start_date != st.session_state.get("last_start_date")
    or end_date != st.session_state.get("last_end_date")
    or model_choice != st.session_state.get("last_model_choice")
)

# Track current selections for change detection.  This must happen BEFORE the
# rerun below: otherwise the stored values never catch up with the widgets and
# every rerun would detect the same change again, looping forever.
st.session_state.last_start_date = start_date
st.session_state.last_end_date = end_date
st.session_state.last_model_choice = model_choice

if st.session_state.is_running and selection_changed:
    _reset_simulation_state(stop=False)
    st.rerun()

# ============================== Paused Mode ==============================

if not st.session_state.is_running and st.session_state.last_fig is not None:
    st.write("Simulation paused...")
    plot_title, plot_container, x_axis_label, info_container = init_simulation_layout()

    timestamp = st.session_state.pred_timestamps[-1] if st.session_state.pred_timestamps else "–"
    prediction = st.session_state.pred_vals[-1] if st.session_state.pred_vals else None
    actual = st.session_state.true_vals[-1] if st.session_state.true_vals else None

    # The Transformer loop advances ``valid_pos``; the LightGBM loop advances
    # ``start_index``.  Pick the counter that matches the selected model.
    if model_choice == TRANSFORMER_LABEL:
        steps_done = st.session_state.valid_pos
    else:
        steps_done = st.session_state.start_index
    progress = _progress_fraction(steps_done, total_steps_ui)

    render_simulation_view(timestamp, prediction, actual, progress, st.session_state.last_fig, paused=True)

# ============================== initialize values ==============================

# LightGBM iterates directly over the filtered test frame from above
if model_choice == LIGHTGBM_LABEL:
    test_df = test_df_filtered.copy()

# Shared state references for storing predictions and ground truths.  These
# are the very list objects held in the session state, so appending to them
# persists across reruns.
true_vals = st.session_state.true_vals
pred_vals = st.session_state.pred_vals
true_timestamps = st.session_state.true_timestamps
pred_timestamps = st.session_state.pred_timestamps

# ============================== LightGBM Simulation ==============================

if model_choice == LIGHTGBM_LABEL and st.session_state.is_running:
    model = load_lightgbm_model()
    st.write("Simulation started...")
    # NOTE(review): this markdown seems to have lost its HTML payload.
    st.markdown('\n', unsafe_allow_html=True)

    plot_title, plot_container, x_axis_label, info_container = init_simulation_layout()

    for i in range(st.session_state.start_index, len(test_df)):
        if not st.session_state.is_running:
            break

        current = test_df.iloc[i]
        timestamp = current["date"]
        features = current[FEATURES].values.reshape(1, -1)
        prediction = model.predict(features)[0]

        pred_vals.append(prediction)
        pred_timestamps.append(timestamp)

        # The observed value is revealed one step late: at step i the actual
        # consumption of step i - 1 becomes known.
        prev_actual = None
        if i >= 1:
            prev_actual = test_df.iloc[i - 1]["consumption_MW"]
            prev_time = test_df.iloc[i - 1]["date"]
            true_vals.append(prev_actual)
            true_timestamps.append(prev_time)

        fig = create_prediction_plot(
            pred_timestamps, pred_vals,
            true_timestamps, true_vals,
            window_hours,
            y_min=plot_y_min,
            y_max=plot_y_max
        )

        render_simulation_view(timestamp, prediction, prev_actual,
                               _progress_fraction(i, len(test_df)), fig)

        plt.close(fig)  # free memory

        st.session_state.start_index = i + 1
        time.sleep(1 / (speed + 1e-9))

    st.success("Simulation completed!")

# ============================== Transformer Simulation ==============================

spinner_placeholder = st.empty()

if model_choice == TRANSFORMER_LABEL and st.session_state.is_running:
    st.write("Simulation started (Transformer)...")
    # NOTE(review): this markdown seems to have lost its HTML payload.
    st.markdown('\n', unsafe_allow_html=True)

    if not st.session_state.first_plot_shown:
        spinner_placeholder.markdown("Running first prediction – please wait...")

    plot_title, plot_container, x_axis_label, info_container = init_simulation_layout()

    # Model, dataset and device
    model, test_dataset, device = load_transformer_model_and_dataset()
    scaler = test_dataset.scaler
    n_channels = test_dataset.n_channels

    # The dataset no longer carries timestamps, so the timestamp of test index
    # 0 is looked up in the (cached) raw data frame.
    test_start_idx = _transformer_test_start_index()
    base_timestamp = df_full["date"].iloc[test_start_idx]

    # Step 1: number of hours from the base timestamp to the next 00:00
    # (0 if the base timestamp is already at midnight).
    start_index = (24 - base_timestamp.hour) % 24

    # Initialise the session start index if it has not been set yet
    if st.session_state.start_index == 0:
        st.session_state.start_index = start_index

    # Dataset indices whose timestamp lies inside the selected date range
    valid_indices = [
        i for i in range(start_index, len(test_dataset))
        if start_date <= (base_timestamp + pd.Timedelta(hours=i)).date() <= end_date
    ]

    # For the progress bar
    total_steps = len(valid_indices)

    # Main loop: ``valid_pos`` is the position inside ``valid_indices`` (not a
    # global dataset index), so resuming continues where the run stopped.
    for i in valid_indices[st.session_state.valid_pos:]:
        if not st.session_state.is_running:
            break

        current_pred = predict_transformer_step(model, test_dataset, i, device)
        current_time = base_timestamp + pd.Timedelta(hours=i)

        pred_vals.append(current_pred)
        pred_timestamps.append(current_time)

        actual_val = None
        if i >= 1:
            # First forecast value of the previous sample
            prev_actual = test_dataset[i - 1][1][0, 0]

            # Rescale back to original units
            dummy_actual = np.zeros((1, n_channels))
            dummy_actual[:, 0] = prev_actual
            actual_val = scaler.inverse_transform(dummy_actual)[0, 0]

            true_time = current_time - pd.Timedelta(hours=1)

            # Skip observations that fall before the selected start date
            if true_time >= pd.to_datetime(start_date):
                true_vals.append(actual_val)
                true_timestamps.append(true_time)

        # Build the plot
        fig = create_prediction_plot(
            pred_timestamps, pred_vals,
            true_timestamps, true_vals,
            window_hours,
            y_min=plot_y_min,
            y_max=plot_y_max
        )

        if len(pred_vals) >= 2 and len(true_vals) >= 1:
            render_simulation_view(current_time, current_pred, actual_val,
                                   _progress_fraction(st.session_state.valid_pos, total_steps), fig)
            if not st.session_state.first_plot_shown:
                spinner_placeholder.empty()
                st.session_state.first_plot_shown = True

        plt.close(fig)  # free memory

        st.session_state.valid_pos += 1
        time.sleep(1 / (speed + 1e-9))

    st.success("Simulation completed!")

# ============================== Scroll Sync ==============================

# NOTE(review): the script/markup body seems to be missing here -- confirm.
st.markdown(""" """, unsafe_allow_html=True)