dlaj committed
Commit 8cc5633 · 1 Parent(s): 6baf912

Deploy from GitHub
.streamlit/config.toml ADDED
@@ -0,0 +1,9 @@
1
+ [theme]
2
+ base="light"
3
+ primaryColor="#FF4B4B"
4
+ backgroundColor="#f8f9fa"
5
+ textColor="#004080"
6
+ secondaryBackgroundColor="#edf1f7"
7
+ font="sans serif"
8
+
9
+
README.md ADDED
@@ -0,0 +1,296 @@
1
+ # Energy Forecasting with Transformer and LightGBM
2
+
3
+ This project focuses on forecasting urban energy consumption based solely on historical usage and temperature data from Chicago (2011–2018). Two model architectures are compared: a LightGBM ensemble model and a Transformer-based neural network (based on the MOMENT time series foundation model). The goal is to predict hourly electricity demand and analyze model performance, interpretability, and generalizability.
4
+
5
+ The project also simulates a real-time setting, where hourly predictions are made sequentially to mirror operational deployment. The modular design allows for adaptation to other urban contexts, assuming a compatible data structure.
6
+
7
+ ---
8
+
9
+ ## Overview
10
+
11
+ * **Goal**: Predict hourly energy consumption using timestamp, temperature, and historical consumption features.
12
+ * **Models**: LightGBM and Time Series Transformer Model (moments).
13
+ * **Results**: Both models perform well; LightGBM achieves the best overall performance.
14
+ * **Dashboard**: Live forecast simulation via Streamlit interface.
15
+ * **Usage Context**: Developed as a prototype for real-time hourly forecasting, with a modular structure that supports adaptation to similar operational settings.
16
+
17
+ ---
18
+
19
+ ## Results
20
+
21
+ ### Evaluation Metrics
22
+
23
+ | Model | RMSE | R² | MAPE |
24
+ | ----------- | ------- | ----- | ------ |
25
+ | Transformer | 3933.57 | 0.972 | 2.32 % |
26
+ | LightGBM | 1383.68 | 0.996 | 0.84 % |
27
+
28
+ > **Note:** All values are in megawatts (MW). Hourly consumption typically ranges from 100,000 to 200,000 MW.
29
+
30
+ * LightGBM achieves the best trade-off between performance and resource efficiency.
31
+ * The Transformer model generalizes well to temporal patterns and may scale better in more complex or multi-network scenarios.
32
+ * Both models show no signs of overfitting, supported by learning curves, consistent evaluation metrics, and additional diagnostics such as residual distribution analysis and noise-feature validation.
33
+
34
+ ---
35
+
36
+ ### Forecast Plots
37
+
38
+ | LightGBM Prediction Plot | Transformer Prediction Plot |
39
+ | :----------------------: | :--------------------------: |
40
+ | ![LightGBM Prediction](assets/lightgbm_prediction_with_timestamp.png) | ![Transformer Prediction](assets/comparison_plot_1month.png) |
41
+
42
+ > **Note:** Example forecast windows are shown (LightGBM: 3 months, Transformer: 1 month).
43
+ > LightGBM maintains highly consistent performance over time, while the Transformer shows occasional over- or underestimation on special peak days.
44
+
45
+ ---
46
+
47
+ ### Learning Curves
48
+
49
+ These plots visualize training dynamics and help detect overfitting.
50
+
51
+ | LightGBM Learning Curve | Transformer Learning Curve |
52
+ | :----------------------: | :------------------------: |
53
+ | ![LightGBM LC](assets/lightgbm_learning_curve.png) | ![Transformer LC](assets/training_plot.png) |
54
+
55
+ * The LightGBM curve shows a stable gap between training and validation RMSE, indicating low overfitting.
56
+ * The Transformer learning curve also converges smoothly without divergence, supporting generalizability.
57
+ * In addition to visual inspection, further checks like residual analysis and a noise feature test confirmed robustness.
58
+
59
+ > **Note:** The LightGBM curve shows boosting rounds with validation RMSE,
60
+ > while the Transformer plot tracks training loss and test metrics per epoch.
61
+
62
+ More plots are available in the respective `/results` directories.
63
+
64
+ ---
65
+
66
+ ## Streamlit Simulation Dashboard
67
+
68
+ * Live hourly forecast simulation
69
+ * Uses the trained models
70
+ * Repeats predictions sequentially for each hour to simulate real-time data flow
71
+ * Hosted on Hugging Face (CPU only, slower prediction speed)
72
+
73
+ You can try the model predictions interactively in the Streamlit dashboard:
74
+
75
+ **Try it here:**
76
+ **[Launch Streamlit App](https://huggingface.co/spaces/dlaj/energy-forecasting-demo)**
77
+
78
+ **Preview:**
79
+
80
+ ![Streamlit Dashboard Preview](assets/streamlit_preview.gif)
81
+
82
+ ---
83
+
84
+ ## Data
85
+
86
+ * **Source**:
87
+
88
+ * [COMED Hourly Consumption Data](https://www.kaggle.com/datasets/robikscube/hourly-energy-consumption)
89
+ * [NOAA Temperature Data](https://www.ncei.noaa.gov/)
90
+ * **Time range**: January 2011 – August 2018
91
+ * **Merged file**: `data/processed/energy_consumption_aggregated_cleaned.csv`
92
+
93
+ ---
94
+
95
+ ## Feature Engineering
96
+
97
+ The models rely on timestamp and temperature data, enriched with derived time-based and lag-based features:
98
+
99
+ * hour\_sin, hour\_cos
100
+ * weekday\_sin, weekday\_cos
101
+ * month\_sin, month\_cos
102
+ * rolling\_mean\_6h
103
+ * temperature\_c
104
+ * consumption\_last\_hour
105
+ * consumption\_yesterday
106
+ * consumption\_last\_week
107
+
108
+ Feature selection was guided by LightGBM feature importance analysis; weak features with nearly no impact, such as `is_weekend`, were removed.
109
+
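The snippet below is a minimal sketch of how these cyclical and lag features can be derived with pandas. It assumes a DataFrame with the project's `date` and `consumption_MW` columns; the authoritative implementation lives in the preprocessing scripts and may differ in details.

```python
# Illustrative sketch only; the actual feature engineering is done in scripts/.
import numpy as np
import pandas as pd


def add_features(df: pd.DataFrame) -> pd.DataFrame:
    df = df.sort_values("date").copy()

    # Cyclical encodings map hour / weekday / month onto the unit circle,
    # so 23:00 and 00:00 end up close to each other.
    df["hour_sin"] = np.sin(2 * np.pi * df["date"].dt.hour / 24)
    df["hour_cos"] = np.cos(2 * np.pi * df["date"].dt.hour / 24)
    df["weekday_sin"] = np.sin(2 * np.pi * df["date"].dt.weekday / 7)
    df["weekday_cos"] = np.cos(2 * np.pi * df["date"].dt.weekday / 7)
    df["month_sin"] = np.sin(2 * np.pi * (df["date"].dt.month - 1) / 12)
    df["month_cos"] = np.cos(2 * np.pi * (df["date"].dt.month - 1) / 12)

    # Lag and rolling features on the hourly target.
    df["consumption_last_hour"] = df["consumption_MW"].shift(1)
    df["consumption_yesterday"] = df["consumption_MW"].shift(24)
    df["consumption_last_week"] = df["consumption_MW"].shift(24 * 7)
    df["rolling_mean_6h"] = df["consumption_MW"].shift(1).rolling(6).mean()

    return df.dropna()
```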
110
+ ### Final LightGBM Feature Importance
111
+
112
+ <img src="assets/lightgbm_feature_importance.png" alt="Feature Importance" style="width: 80%;"/>
113
+
114
+ ---
115
+
116
+ ## Model Development
117
+
118
+ ### LightGBM
119
+
120
+ * Custom grid search with over 50 parameter combinations
121
+ * Parameters tested:
122
+
123
+ * num\_leaves, max\_depth, learning\_rate, lambda\_l1, lambda\_l2, min\_split\_gain
124
+ * Final Parameters:
125
+
126
+ * learning\_rate: 0.05
127
+ * num\_leaves: 15
128
+ * max\_depth: 5
129
+ * lambda\_l1: 1.0
130
+ * lambda\_l2: 0.0
131
+ * min\_split\_gain: 0.0
132
+ * n\_estimators: 1000
133
+ * objective: regression
134
+
135
+ Overfitting was monitored using a noise feature and RMSE gaps. See grid search results:
136
+ `notebooks/lightgbm/lightgbm_gridsearch_results.csv`
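The grid search script itself is not part of this commit; the sketch below shows roughly how such a search over the listed parameters can be set up. The parameter ranges are illustrative, and `X_train`, `y_train`, `X_valid`, `y_valid` are assumed to come from the time-based split used in `train_lightgbm.py`.

```python
# Illustrative grid search sketch; the actual notebook may structure this differently.
import itertools

import numpy as np
from lightgbm import LGBMRegressor, early_stopping
from sklearn.metrics import mean_squared_error

param_grid = {
    "num_leaves": [15, 31, 63],
    "max_depth": [5, 7, -1],
    "learning_rate": [0.01, 0.05, 0.1],
    "lambda_l1": [0.0, 1.0],
    "lambda_l2": [0.0, 1.0],
    "min_split_gain": [0.0, 0.1],
}

results = []
for values in itertools.product(*param_grid.values()):
    params = dict(zip(param_grid.keys(), values))
    model = LGBMRegressor(n_estimators=1000, objective="regression", **params)
    model.fit(
        X_train,
        y_train,
        eval_set=[(X_valid, y_valid)],
        eval_metric="rmse",
        callbacks=[early_stopping(50)],
    )
    rmse = np.sqrt(mean_squared_error(y_valid, model.predict(X_valid)))
    results.append({**params, "valid_rmse": rmse})
```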
137
+
138
+ ### Transformer (Moments Time Series Transformer)
139
+
140
+ * Based on pretrained Moments model
141
+ * Only the forecasting head was fine-tuned in the standard training setup
142
+ * Also tested variants with unfrozen encoder layers and dropout
143
+ * Final config:
144
+
145
+ * task\_name: forecasting
146
+ * forecast\_horizon: 24
147
+ * head\_dropout: 0.1
148
+ * weight\_decay: 0
149
+ * freeze\_encoder: True
150
+ * freeze\_embedder: True
151
+ * freeze\_head: False
152
+
153
+ ---
154
+
155
+ ## Project Structure
156
+
157
+ ```
158
+ energy-forecasting-transformer-lightgbm/
159
+ ├── data/ # Raw, external, processed datasets
160
+ ├── notebooks/ # EDA, lightgbm and transformer prototypes, including hyperparameter tuning and model selection
161
+ ├── scripts/ # Data preprocessing scripts
162
+ ├── lightgbm_model/ # LightGBM model, scripts, results
163
+ ├── transformer_model/ # Transformer model, scripts, results
164
+ ├── streamlit_simulation/ # Streamlit dashboard
165
+ ├── requirements.txt # Main environment
166
+ ├── requirements_lgbm.txt # Optional for LightGBM
167
+ ├── setup.py
168
+ └── README.md
169
+ ```
170
+
171
+ ---
172
+
173
+ ## Reproducibility
174
+
175
+ You can reuse this pipeline with any dataset, as long as it contains the following key columns:
176
+
177
+ ```csv
178
+ timestamp, # hourly timestamp (e.g., "2018-01-01 14:00")
179
+ consumption, # energy usage (aggregated; for individual users, consider adding an ID column)
180
+ temperature # hourly
181
+ ```
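As a quick sanity check before reusing the pipeline, you can verify that your file exposes the expected columns (a sketch; in this repository the processed file uses the column names `date`, `consumption_MW`, and `temperature_c`):

```python
# Sanity-check sketch for a custom dataset (column names as used in this project).
import pandas as pd

df = pd.read_csv(
    "data/processed/energy_consumption_aggregated_cleaned.csv", parse_dates=["date"]
)

required = {"date", "consumption_MW", "temperature_c"}
missing = required - set(df.columns)
assert not missing, f"Missing required columns: {missing}"
assert df["date"].is_monotonic_increasing, "Timestamps should be sorted ascending"
```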
182
+
183
+ ### Notes:
184
+
185
+ * Transformer model training is **very slow on CPU**, and also on AMD GPUs (no CUDA support)
186
+ * Recommended: use **CUDA or Google Colab + CUDA GPU runtime** for transformer training
187
+ * All scripts are modular and can be executed separately
188
+
189
+ ---
190
+
191
+ ## CI/CD & DevOps Setup
192
+
193
+ This project includes a lightweight CI pipeline using GitHub Actions:
194
+
195
+ * **CI**:
196
+ - Runs `pytest` on every push
197
+ - Builds and validates the Docker image
198
+
199
+ * **Code quality checks**:
200
+ - Uses `pre-commit` hooks with `black`, `isort`, and `ruff`
201
+ - Ensures consistent formatting and linting before commits
202
+
203
+ To enable pre-commit locally:
204
+
205
+ ```bash
206
+ pre-commit install
207
+ ```
208
+
209
+ ---
210
+
211
+ ## Run Locally
212
+
213
+ ### Prerequisites
214
+
215
+ * Python 3.9–3.11 (required for Moments Transformer)
216
+
217
+ ### Installation
218
+
219
+ ```bash
220
+ git clone https://github.com/dlajic/energy-forecasting-transformer-lightgbm.git
221
+ cd energy-forecasting-transformer-lightgbm
222
+ pip install -r requirements.txt
223
+ ```
224
+
225
+ ### Preprocess Data
226
+
227
+ ```bash
228
+ python -m scripts.data_preprocessing.merge_temperature_data # merges raw temperature and energy data (only needed with raw inputs)
229
+ python -m scripts.data_preprocessing.preprocess_data # launches full preprocessing pipeline; use if data already matches expected format
230
+ ```
231
+
232
+ ### Train Models
233
+
234
+ ```bash
235
+ python -m lightgbm_model.scripts.train.train_lightgbm
236
+ python -m transformer_model.scripts.training.train
237
+ ```
238
+
239
+ ### Evaluate Models
240
+
241
+ ```bash
242
+ python -m lightgbm_model.scripts.eval.eval_lightgbm
243
+ python -m transformer_model.scripts.evaluation.evaluate
244
+ python -m transformer_model.scripts.evaluation.plot_learning_curves
245
+ ```
246
+
247
+ ### Run Streamlit Dashboard (local)
248
+
249
+ ```bash
250
+ streamlit run streamlit_simulation/app.py
251
+ ```
252
+
253
+ For editable install:
254
+
255
+ ```bash
256
+ pip install -e .
257
+ ```
258
+
259
+ ## Run App with Docker
260
+
261
+ This project also supports containerized execution using Docker:
262
+
263
+
264
+ ```bash
265
+ # Start app with Docker Compose (Linux)
266
+ ./start.sh
267
+
268
+ # Or on Windows (PowerShell)
269
+ ./start.ps1
270
+ ```
271
+
272
+ Make sure Docker (Docker-Desktop) is running before executing the script.
273
+
274
+ This will:
275
+
276
+ 1. Build the Docker image
277
+ 2. Start the Streamlit app on localhost:8501
278
+ 3. Open it automatically in your browser
279
+
280
+ ---
281
+
282
+ ## Author
283
+
284
+ Dean Lajic
285
+ GitHub: [dlajic](https://github.com/dlajic)
286
+
287
+ ---
288
+
289
+ ## References
290
+
291
+ - Moments Time Series Transformer
292
+ https://github.com/moment-timeseries-foundation-model/moment
293
+ - COMED Consumption Dataset
294
+ https://www.kaggle.com/datasets/robikscube/hourly-energy-consumption
295
+ - NOAA Weather Data
296
+ https://www.ncei.noaa.gov
lightgbm_model/scripts/__init__.py ADDED
@@ -0,0 +1 @@
1
+ # __init__.py
lightgbm_model/scripts/config_lightgbm.py ADDED
@@ -0,0 +1,41 @@
1
+ # config.py
2
+ import os
3
+
4
+ # === Paths ===
5
+ BASE_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
6
+ DATA_PATH = os.path.join(
7
+ BASE_DIR, "..", "data", "processed", "energy_consumption_aggregated_cleaned.csv"
8
+ )
9
+ RESULTS_DIR = os.path.join(BASE_DIR, "results")
10
+ MODEL_DIR = os.path.join(BASE_DIR, "model")
11
+
12
+ # === Feature-Definition ===
13
+ FEATURES = [
14
+ "hour_sin",
15
+ "hour_cos",
16
+ "weekday_sin",
17
+ "weekday_cos",
18
+ "rolling_mean_6h",
19
+ "month_sin",
20
+ "month_cos",
21
+ "temperature_c",
22
+ "consumption_last_week",
23
+ "consumption_yesterday",
24
+ "consumption_last_hour",
25
+ ]
26
+ TARGET = "consumption_MW"
27
+
28
+ # === Hyperparameters for LightGBM ===
29
+ LIGHTGBM_PARAMS = {
30
+ "learning_rate": 0.05,
31
+ "num_leaves": 15,
32
+ "max_depth": 5,
33
+ "lambda_l1": 1.0,
34
+ "lambda_l2": 0.0,
35
+ "min_split_gain": 0.0,
36
+ "n_estimators": 1000,
37
+ "objective": "regression",
38
+ }
39
+
40
+ # === Early Stopping ===
41
+ EARLY_STOPPING_ROUNDS = 50
lightgbm_model/scripts/eval/eval_lightgbm.py ADDED
@@ -0,0 +1,156 @@
1
+ # eval_model.py
2
+
3
+ import json
4
+ import os
5
+ import pickle
6
+
7
+ import matplotlib.pyplot as plt
8
+ import numpy as np
9
+ import pandas as pd
10
+ from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
11
+
12
+ from lightgbm_model.scripts.config_lightgbm import DATA_PATH, RESULTS_DIR
13
+ from lightgbm_model.scripts.utils import load_lightgbm_model
14
+
15
+ # === Prepare results directory ===
16
+ os.makedirs(RESULTS_DIR, exist_ok=True)
17
+
18
+ # === Load model and eval_result ===
19
+ # Load model
20
+ model = load_lightgbm_model()
21
+
22
+ # Load evaluation results
23
+ with open(os.path.join(RESULTS_DIR, "lightgbm_eval_result.pkl"), "rb") as f:
24
+ eval_result = pickle.load(f)
25
+ X_train = pd.read_csv(os.path.join(RESULTS_DIR, "X_train.csv"))
26
+ X_test = pd.read_csv(os.path.join(RESULTS_DIR, "X_test.csv"))
27
+ y_test = pd.read_csv(os.path.join(RESULTS_DIR, "y_test.csv"))
28
+
29
+ # === Learning curve ===
30
+ train_rmse = eval_result["training"]["rmse"]
31
+ valid_rmse = eval_result["valid_1"]["rmse"]
32
+
33
+ plt.figure(figsize=(10, 5))
34
+ plt.plot(train_rmse, label="Train RMSE")
35
+ plt.plot(valid_rmse, label="Valid RMSE")
36
+ plt.axvline(model.best_iteration_, color="gray", linestyle="--", label="Best Iteration")
37
+ plt.xlabel("Boosting Round")
38
+ plt.ylabel("RMSE")
39
+ plt.title("LightGBM Learning Curve")
40
+ plt.legend()
41
+ plt.tight_layout()
42
+ plt.savefig(os.path.join(RESULTS_DIR, "lightgbm_learning_curve.png"))
43
+ # plt.show()
44
+
45
+ # === Compute metrics ===
46
+ y_pred = model.predict(X_test)
47
+ mae = mean_absolute_error(y_test, y_pred)
48
+ rmse = np.sqrt(mean_squared_error(y_test, y_pred))
49
+ mape = (
50
+ np.mean(
51
+ np.abs(
52
+ (y_test.values.flatten() - y_pred)
53
+ / np.where(y_test.values.flatten() == 0, 1e-10, y_test.values.flatten())
54
+ )
55
+ )
56
+ * 100
57
+ )
58
+ r2 = r2_score(y_test, y_pred)
59
+
60
+ print(f"Test MAPE: {mape:.5f} %")
61
+ print(f"Test MAE: {mae:.5f}")
62
+ print(f"Test RMSE: {rmse:.5f}")
63
+ print(f"Test R2: {r2:.5f}")
64
+
65
+ metrics = {
66
+ "model": "LightGBM",
67
+ "MAE": round(mae, 2),
68
+ "RMSE": round(rmse, 2),
69
+ "MAPE (%)": round(mape, 2),
70
+ "R2": round(r2, 4),
71
+ "unit": "MW",
72
+ }
73
+
74
+ # Set output path
75
+ output_path = os.path.join(RESULTS_DIR, "evaluation_metrics_lightgbm.json")
76
+ # Save metrics
77
+ with open(output_path, "w") as f:
78
+ json.dump(metrics, f, indent=4)
79
+
80
+ print(f"Metrics saved to {output_path}")
81
+
82
+ # === Feature Importance ===
83
+ feature_importance = pd.DataFrame(
84
+ {"Feature": X_train.columns, "Importance": model.feature_importances_}
85
+ ).sort_values(by="Importance", ascending=False)
86
+
87
+ plt.figure(figsize=(10, 6))
88
+ plt.barh(feature_importance["Feature"], feature_importance["Importance"])
89
+ plt.xlabel("Feature Importance")
90
+ plt.title("LightGBM Feature Importance")
91
+ plt.gca().invert_yaxis()
92
+ plt.tight_layout()
93
+ plt.savefig(os.path.join(RESULTS_DIR, "lightgbm_feature_importance.png"))
94
+ # plt.show()
95
+
96
+ # === Comparison plots ===
97
+ results_df = pd.DataFrame(
98
+ {
99
+ "True Consumption (MW)": y_test.values.flatten(),
100
+ "Predicted Consumption (MW)": y_pred,
101
+ }
102
+ )
103
+
104
+ # Append timestamps
105
+ full_df = pd.read_csv(DATA_PATH)
106
+ test_dates = full_df.iloc[int(len(full_df) * 0.8) :]["date"].reset_index(drop=True)
107
+ results_df["Timestamp"] = pd.to_datetime(test_dates)
108
+
109
+ # Full plot
110
+ plt.figure(figsize=(15, 6))
111
+ plt.plot(
112
+ results_df["Timestamp"],
113
+ results_df["True Consumption (MW)"],
114
+ label="True",
115
+ color="darkblue",
116
+ )
117
+ plt.plot(
118
+ results_df["Timestamp"],
119
+ results_df["Predicted Consumption (MW)"],
120
+ label="Predicted",
121
+ color="red",
122
+ linestyle="--",
123
+ )
124
+ plt.title("Predicted vs True Consumption")
125
+ plt.xlabel("Timestamp")
126
+ plt.ylabel("Consumption (MW)")
127
+ plt.legend()
128
+ plt.tight_layout()
129
+ plt.savefig(os.path.join(RESULTS_DIR, "lightgbm_comparison_plot.png"))
130
+ # plt.show()
131
+
132
+ # Subset Plot
133
+ subset = results_df.iloc[: len(results_df) // 10]
134
+ plt.figure(figsize=(15, 6))
135
+ plt.plot(
136
+ subset["Timestamp"], subset["True Consumption (MW)"], label="True", color="darkblue"
137
+ )
138
+ plt.plot(
139
+ subset["Timestamp"],
140
+ subset["Predicted Consumption (MW)"],
141
+ label="Predicted",
142
+ color="red",
143
+ linestyle="--",
144
+ )
145
+ plt.title("Predicted vs True (First decile)")
146
+ plt.xlabel("Timestamp")
147
+ plt.ylabel("Consumption (MW)")
148
+ plt.legend()
149
+ plt.tight_layout()
150
+ plt.savefig(os.path.join(RESULTS_DIR, "lightgbm_prediction_with_timestamp.png"))
151
+ # plt.show()
152
+
153
+
154
+ # === End message ===
155
+ print("\nEvaluation completed.")
156
+ print(f"All plots stored in:\n→ {RESULTS_DIR}")
lightgbm_model/scripts/model_loader_wrapper.py ADDED
@@ -0,0 +1,11 @@
1
+ from lightgbm_model.scripts.utils import load_lightgbm_model as real_model
2
+ from streamlit_simulation.utils.env import use_dummy
3
+
4
+
5
+ def load_lightgbm_model():
6
+ if use_dummy():
7
+ from streamlit_simulation.utils.dummy import DummyLightGBMModel
8
+
9
+ return DummyLightGBMModel()
10
+ else:
11
+ return real_model()
lightgbm_model/scripts/train/train_lightgbm.py ADDED
@@ -0,0 +1,66 @@
1
+ # train_lightgbm.py
2
+
3
+ import os
4
+ import pickle
5
+
6
+ import pandas as pd
7
+ from lightgbm import LGBMRegressor, early_stopping, record_evaluation
8
+
9
+ from lightgbm_model.scripts.config_lightgbm import (DATA_PATH,
10
+ EARLY_STOPPING_ROUNDS,
11
+ FEATURES, LIGHTGBM_PARAMS,
12
+ MODEL_DIR, RESULTS_DIR,
13
+ TARGET)
14
+
15
+ # === Load Data ===
16
+ df = pd.read_csv(DATA_PATH)
17
+
18
+ # Drop date (used later for plots only)
19
+ df = df.drop(columns=["date"], errors="ignore")
20
+
21
+ # === Time-based Split (70% train, 10% valid, 20% test) ===
22
+ train_size = int(len(df) * 0.7)
23
+ valid_size = int(len(df) * 0.1)
24
+ df_train = df.iloc[:train_size]
25
+ df_valid = df.iloc[train_size : train_size + valid_size]
26
+ df_test = df.iloc[train_size + valid_size :]
27
+
28
+ X_train, y_train = df_train[FEATURES], df_train[TARGET]
29
+ X_valid, y_valid = df_valid[FEATURES], df_valid[TARGET]
30
+ X_test, y_test = df_test[FEATURES], df_test[TARGET]
31
+
32
+
33
+ # === Init LightGBM model ===
34
+ eval_result = {}
35
+
36
+ model = LGBMRegressor(**LIGHTGBM_PARAMS, verbosity=-1)
37
+
38
+ model.fit(
39
+ X_train,
40
+ y_train,
41
+ eval_set=[(X_train, y_train), (X_valid, y_valid)],
42
+ eval_metric="rmse",
43
+ callbacks=[early_stopping(EARLY_STOPPING_ROUNDS), record_evaluation(eval_result)],
44
+ )
45
+
46
+ # === Save model ===
47
+ os.makedirs(MODEL_DIR, exist_ok=True)
48
+ model_path = os.path.join(MODEL_DIR, "lightgbm_final_model.pkl")
49
+
50
+ with open(model_path, "wb") as f:
51
+ pickle.dump(model, f)
52
+
53
+ # === Save evaluation results ===
54
+ os.makedirs(RESULTS_DIR, exist_ok=True)
55
+ eval_result_path = os.path.join(RESULTS_DIR, "lightgbm_eval_result.pkl")
56
+
57
+ with open(eval_result_path, "wb") as f:
58
+ pickle.dump(eval_result, f)
59
+
60
+ print(f"Model saved to: {model_path}")
61
+ print(f"Eval results saved to: {eval_result_path}")
62
+
63
+ # === Save data for evaluation ===
64
+ X_train.to_csv(os.path.join(RESULTS_DIR, "X_train.csv"), index=False)
65
+ X_test.to_csv(os.path.join(RESULTS_DIR, "X_test.csv"), index=False)
66
+ y_test.to_csv(os.path.join(RESULTS_DIR, "y_test.csv"), index=False)
lightgbm_model/scripts/utils.py ADDED
@@ -0,0 +1,9 @@
1
+ import os
2
+ import pickle
3
+
4
+ MODEL_PATH = os.path.join("lightgbm_model", "model", "lightgbm_final_model.pkl")
5
+
6
+
7
+ def load_lightgbm_model():
8
+ with open(MODEL_PATH, "rb") as f:
9
+ return pickle.load(f)
requirements.txt ADDED
@@ -0,0 +1,38 @@
1
+ # =============================
2
+ # Requirements for Energy Prediction Project
3
+ # =============================
4
+
5
+ # Python 3.11 environment recommended, since momentfm does not work with later versions
6
+
7
+ # Moment Foundation Model (forecasting backbone)
8
+ momentfm @ git+https://github.com/moment-timeseries-foundation-model/moment.git@37a8bde4eb3dd340bebc9b54a3b893bcba62cd4f
9
+
10
+ # === Core Python stack ===
11
+ numpy==1.25.2 # Numerical operations
12
+ pandas==2.2.2 # Data manipulation and analysis
13
+ matplotlib==3.10.0 # Plotting and visualizations
14
+
15
+
16
+ # === Machine Learning ===
17
+ scikit-learn==1.6.1 # Evaluation metrics and preprocessing utilities
18
+ torch==2.6.0 # PyTorch with CUDA 12.4 (GPU support)
19
+ #torchvision==0.21.0 # Optional (can support visual tasks, not critical here)
20
+ #torchaudio==2.6.0 # Optional (comes with torch install, can stay)
21
+
22
+ # === Utilities ===
23
+ tqdm==4.67.1 # Progress bars
24
+ ipywidgets>=8.0 # Enables tqdm progress bars in Jupyter/Colab
25
+ pprintpp==0.4.0 # Prettier print formatting for nested dicts (used for model output check)
26
+
27
+ # === lightgbm ===
28
+ lightgbm==4.3.0 # Boosted Trees for tabular modeling (used for baseline and feature selection)
29
+
30
+ # === Streamlit App ===
31
+ streamlit>=1.30.0
32
+ plotly>=5.0.0
33
+
34
+ # === for pytest/env dummy/pre-commit/huggingface ====
35
+ pytest
36
+ python-dotenv
37
+ pre-commit
38
+ huggingface_hub
setup.py ADDED
@@ -0,0 +1,7 @@
1
+ from setuptools import find_packages, setup
2
+
3
+ setup(
4
+ name="energy_prediction",
5
+ version="0.1",
6
+ packages=find_packages(),
7
+ )
streamlit_simulation/__init__.py ADDED
@@ -0,0 +1 @@
1
+ # __init__.py
streamlit_simulation/app.py ADDED
@@ -0,0 +1,556 @@
1
+ import time
2
+ import warnings
3
+
4
+ import matplotlib.dates as mdates
5
+ import matplotlib.pyplot as plt
6
+ import numpy as np
7
+ import pandas as pd
8
+ import streamlit as st
9
+ import torch
10
+ from config_streamlit import DATA_PATH, PLOT_COLOR, TRAIN_RATIO
11
+
12
+ from lightgbm_model.scripts.config_lightgbm import FEATURES
13
+ from lightgbm_model.scripts.model_loader_wrapper import load_lightgbm_model
14
+ from streamlit_simulation.utils_streamlit import load_data as load_data_raw
15
+ from transformer_model.scripts.config_transformer import (FORECAST_HORIZON,
16
+ SEQ_LEN)
17
+ from transformer_model.scripts.utils.informer_dataset_class import \
18
+ InformerDataset
19
+ from transformer_model.scripts.utils.model_loader_wrapper import \
20
+ load_model_and_dataset
21
+
22
+ # ============================== Layout ==============================
23
+
24
+ # Streamlit & warnings config
25
+ warnings.filterwarnings("ignore", category=FutureWarning)
26
+ st.set_page_config(page_title="Electricity Consumption Forecast", layout="wide")
27
+
28
+ # CSS part
29
+ st.markdown(
30
+ f"""
31
+ <style>
32
+ .stButton > button {{
33
+ background-color: {PLOT_COLOR};
34
+ }}
35
+
36
+ /* Also removes the empty space above the app */
37
+ header[data-testid="stHeader"] {{
38
+ display: none !important;
39
+ height: 0px !important;
40
+ visibility: hidden !important;
41
+ }}
42
+
43
+ .block-container {{
44
+ padding-top: 0.5rem !important;
45
+ }}
46
+
47
+ </style>
48
+ """,
49
+ unsafe_allow_html=True,
50
+ )
51
+
52
+
53
+ st.title("Electricity Consumption Forecast: Hourly Simulation")
54
+ st.write("Welcome to the simulation interface!")
55
+ st.info(
56
+ "**Simulation Overview:**\n\n"
57
+ "This dashboard provides an hourly electricity consumption forecast using two different models: "
58
+ "**LightGBM** and a **Transformer (moment-based)**. Both models generate a fresh prediction at every time step "
59
+ "(i.e., every simulated hour).\n\n"
60
+ "Note: Since this app runs on a limited CPU on Hugging Face Spaces, the Transformer model may respond slower "
61
+ "compared to local execution. On a standard local CPU, performance is significantly better."
62
+ )
63
+
64
+
65
+ # ============================== Session State Init ===============================
66
+ def init_session_state():
67
+ defaults = {
68
+ "is_running": False,
69
+ "start_index": 0,
70
+ "true_vals": [],
71
+ "pred_vals": [],
72
+ "true_timestamps": [],
73
+ "pred_timestamps": [],
74
+ "last_fig": None,
75
+ "valid_pos": 0,
76
+ "first_plot_shown": False,
77
+ }
78
+ for key, value in defaults.items():
79
+ if key not in st.session_state:
80
+ st.session_state[key] = value
81
+
82
+
83
+ init_session_state()
84
+
85
+
86
+ # ============================== Loaders Cache ==============================
87
+ @st.cache_data
88
+ def load_cached_lightgbm_model():
89
+ return load_lightgbm_model()
90
+
91
+
92
+ @st.cache_resource
93
+ def load_transformer_model_and_dataset():
94
+ return load_model_and_dataset()
95
+
96
+
97
+ @st.cache_data
98
+ def load_data():
99
+ return load_data_raw()
100
+
101
+
102
+ # ============================== Utility Functions ==============================
103
+
104
+
105
+ def predict_transformer_step(model, dataset, idx, device):
106
+ """Performs a single prediction step with the transformer model."""
107
+ timeseries, _, input_mask = dataset[idx]
108
+ timeseries = torch.tensor(timeseries, dtype=torch.float32).unsqueeze(0).to(device)
109
+ input_mask = torch.tensor(input_mask, dtype=torch.bool).unsqueeze(0).to(device)
110
+
111
+ with torch.no_grad():
112
+ output = model(x_enc=timeseries, input_mask=input_mask)
113
+
114
+ pred = output.forecast[:, 0, :].cpu().numpy().flatten()
115
+
116
+ # Rescale back to the original units
117
+ dummy = np.zeros((len(pred), dataset.n_channels))
118
+ dummy[:, 0] = pred
119
+ pred_original = dataset.scaler.inverse_transform(dummy)[:, 0]
120
+
121
+ return float(pred_original[0])
122
+
123
+
124
+ def init_simulation_layout():
125
+ """Creates layout containers for plot and info sections."""
126
+ col1, spacer, col2 = st.columns([3, 0.2, 1])
127
+ plot_title = col1.empty()
128
+ plot_container = col1.empty()
129
+ x_axis_label = col1.empty()
130
+ info_container = col2.empty()
131
+ return plot_title, plot_container, x_axis_label, info_container
132
+
133
+
134
+ def create_prediction_plot(
135
+ pred_timestamps,
136
+ pred_vals,
137
+ true_timestamps,
138
+ true_vals,
139
+ window_hours,
140
+ y_min=None,
141
+ y_max=None,
142
+ ):
143
+ """Generates the matplotlib figure for plotting prediction vs. actual."""
144
+ fig, ax = plt.subplots(
145
+ figsize=(8, 5), constrained_layout=True, facecolor=PLOT_COLOR
146
+ )
147
+ ax.set_facecolor(PLOT_COLOR)
148
+
149
+ ax.plot(
150
+ pred_timestamps[-window_hours:],
151
+ pred_vals[-window_hours:],
152
+ label="Prediction",
153
+ color="#EF233C",
154
+ linestyle="--",
155
+ )
156
+ if true_vals:
157
+ ax.plot(
158
+ true_timestamps[-window_hours:],
159
+ true_vals[-window_hours:],
160
+ label="Actual",
161
+ color="#0077B6",
162
+ )
163
+
164
+ ax.set_ylabel("Consumption (MW)", fontsize=8)
165
+ ax.legend(
166
+ fontsize=8,
167
+ loc="upper left",
168
+ bbox_to_anchor=(0, 0.95),
169
+ # facecolor= INPUT_BG, # INPUT_BG
170
+ # edgecolor= ACCENT_COLOR, # ACCENT_COLOR
171
+ # labelcolor= TEXT_COLOR # TEXT_COLOR
172
+ )
173
+ ax.yaxis.grid(True, linestyle=":", linewidth=0.5, alpha=0.7)
174
+ ax.set_ylim(y_min, y_max)
175
+ ax.xaxis.set_major_locator(mdates.DayLocator(interval=1))
176
+ ax.xaxis.set_major_formatter(mdates.DateFormatter("%m-%d"))
177
+ ax.tick_params(axis="x", labelrotation=0, labelsize=5)
178
+ ax.tick_params(axis="y", labelsize=5)
179
+ # fig.patch.set_facecolor('#e6ecf0') # outer area
180
+
181
+ for spine in ax.spines.values():
182
+ spine.set_visible(False)
183
+
184
+ st.session_state.last_fig = fig
185
+ return fig
186
+
187
+
188
+ def render_simulation_view(timestamp, prediction, actual, progress, fig, paused=False):
189
+ """Displays the simulation plot and metrics in the UI."""
190
+ title = "Actual vs. Prediction (Paused)" if paused else "Actual vs. Prediction"
191
+ plot_title.markdown(
192
+ f"<div style='text-align: center; font-size: 20pt; font-weight: bold; margin-bottom: -0.7rem; margin-top: 0rem;'>"
193
+ f"{title}</div>",
194
+ unsafe_allow_html=True,
195
+ )
196
+ plot_container.pyplot(fig)
197
+
198
+ # st.markdown("<div style='margin-bottom: 0.5rem;'></div>", unsafe_allow_html=True)
199
+ # x_axis_label.markdown(f"<div style='text-align: center; font-size: 13pt; color: {TEXT_COLOR}; margin-top: -0.5rem;'>"f"Time</div>",unsafe_allow_html=True)
200
+
201
+ with info_container.container():
202
+ st.markdown(
203
+ f"<span style='font-size: 24px; font-weight: 600;'>Time: {timestamp}</span>",
204
+ unsafe_allow_html=True,
205
+ )
206
+ st.metric(
207
+ "Prediction", f"{prediction:,.0f} MW" if prediction is not None else "–"
208
+ )
209
+ st.metric("Actual", f"{actual:,.0f} MW" if actual is not None else "–")
210
+ st.caption("Simulation Progress")
211
+ st.progress(progress)
212
+
213
+ if len(st.session_state.true_vals) > 1:
214
+ true_arr = np.array(st.session_state.true_vals)
215
+ pred_arr = np.array(st.session_state.pred_vals[:-1])
216
+ min_len = min(len(true_arr), len(pred_arr))
217
+ if min_len >= 1:
218
+ errors = np.abs(true_arr[:min_len] - pred_arr[:min_len])
219
+ mape = (
220
+ np.mean(
221
+ errors
222
+ / np.where(true_arr[:min_len] == 0, 1e-10, true_arr[:min_len])
223
+ )
224
+ * 100
225
+ )
226
+ mae = np.mean(errors)
227
+ max_error = np.max(errors)
228
+
229
+ st.divider()
230
+ st.markdown(
231
+ "<span style='font-size: 24px; font-weight: 600; '>Interim Metrics</span>",
232
+ unsafe_allow_html=True,
233
+ )
234
+ st.metric("MAPE (so far)", f"{mape:.2f} %")
235
+ st.metric("MAE (so far)", f"{mae:,.0f} MW")
236
+ st.metric("Max Error", f"{max_error:,.0f} MW")
237
+
238
+
239
+ # ============================== Data Preparation ==============================
240
+
241
+ df_full = load_data()
242
+
243
+ # Split Train/Test
244
+ train_size = int(len(df_full) * TRAIN_RATIO)
245
+ test_df_raw = df_full.iloc[train_size:].reset_index(drop=True)
246
+
247
+ # Start at first full hour (00:00)
248
+ first_full_day_index = test_df_raw[
249
+ test_df_raw["date"].dt.time == pd.Timestamp("00:00:00").time()
250
+ ].index[0]
251
+ test_df_full = test_df_raw.iloc[first_full_day_index:].reset_index(drop=True)
252
+
253
+ # Select simulation window via date picker
254
+ min_date = test_df_full["date"].min().date()
255
+ max_date = test_df_full["date"].max().date()
256
+
257
+ # ============================== UI Controls ==============================
258
+
259
+ with st.sidebar:
260
+ st.header("⚙️ Simulation Settings")
261
+
262
+ st.subheader("General Settings")
263
+ model_choice = st.selectbox(
264
+ "Choose prediction model", ["LightGBM", "Transformer Model (moments)"]
265
+ )
266
+ if model_choice == "Transformer Model (moments)":
267
+ st.caption(
268
+ "⚠️ Note: Transformer model runs slower without GPU. (Use Speed = 10)"
269
+ )
270
+ window_days = st.selectbox("Display window (days)", options=[3, 5, 7], index=0)
271
+ window_hours = window_days * 24
272
+ speed = st.slider("Speed", 1, 10, 5)
273
+
274
+ st.subheader("Date Range")
275
+ start_date = st.date_input(
276
+ "Start Date", value=min_date, min_value=min_date, max_value=max_date
277
+ )
278
+ end_date = st.date_input(
279
+ "End Date", value=max_date, min_value=min_date, max_value=max_date
280
+ )
281
+
282
+ # ============================== Data Preparation (filtered) ==============================
283
+
284
+ # final filtered date window
285
+ test_df_filtered = test_df_full[
286
+ (test_df_full["date"].dt.date >= start_date)
287
+ & (test_df_full["date"].dt.date <= end_date)
288
+ ].reset_index(drop=True)
289
+
290
+ # For progression bar
291
+ total_steps_ui = len(test_df_filtered)
292
+
293
+ # ============================== Buttons ==============================
294
+
295
+ st.markdown("### Start Simulation")
296
+ col1, col2, col3 = st.columns([1, 1, 4])
297
+ with col1:
298
+ play_pause_text = "▶️ Start" if not st.session_state.is_running else "⏸️ Pause"
299
+ if st.button(play_pause_text, use_container_width=True):
300
+ st.session_state.is_running = not st.session_state.is_running
301
+ st.rerun()
302
+ with col2:
303
+ reset_button = st.button("🔄 Reset", use_container_width=True)
304
+
305
+ # Reset logic
306
+ if reset_button:
307
+ st.session_state.start_index = 0
308
+ st.session_state.pred_vals = []
309
+ st.session_state.true_vals = []
310
+ st.session_state.pred_timestamps = []
311
+ st.session_state.true_timestamps = []
312
+ st.session_state.last_fig = None
313
+ st.session_state.is_running = False
314
+ st.session_state.valid_pos = 0
315
+ st.session_state.first_plot_shown = False
316
+ st.rerun()
317
+
318
+ # Auto-reset on critical parameter change while running
319
+ if st.session_state.is_running and (
320
+ start_date != st.session_state.get("last_start_date")
321
+ or end_date != st.session_state.get("last_end_date")
322
+ or model_choice != st.session_state.get("last_model_choice")
323
+ ):
324
+ st.session_state.start_index = 0
325
+ st.session_state.pred_vals = []
326
+ st.session_state.true_vals = []
327
+ st.session_state.pred_timestamps = []
328
+ st.session_state.true_timestamps = []
329
+ st.session_state.last_fig = None
330
+ st.session_state.valid_pos = 0
331
+ st.session_state.first_plot_shown = False
332
+ st.rerun()
333
+
334
+ # Track current selections for change detection
335
+ st.session_state.last_start_date = start_date
336
+ st.session_state.last_end_date = end_date
337
+ st.session_state.last_model_choice = model_choice
338
+
339
+
340
+ # ============================== Paused Mode ==============================
341
+
342
+ if not st.session_state.is_running and st.session_state.last_fig is not None:
343
+ st.write("Simulation paused...")
344
+ plot_title, plot_container, x_axis_label, info_container = init_simulation_layout()
345
+
346
+ timestamp = (
347
+ st.session_state.pred_timestamps[-1]
348
+ if st.session_state.pred_timestamps
349
+ else "–"
350
+ )
351
+ prediction = st.session_state.pred_vals[-1] if st.session_state.pred_vals else None
352
+ actual = st.session_state.true_vals[-1] if st.session_state.true_vals else None
353
+ progress = st.session_state.start_index / total_steps_ui
354
+
355
+ render_simulation_view(
356
+ timestamp, prediction, actual, progress, st.session_state.last_fig, paused=True
357
+ )
358
+
359
+
360
+ # ============================== initialize values ==============================
361
+
362
+ # if lightGbm use testdata from above
363
+ if model_choice == "LightGBM":
364
+ test_df = test_df_filtered.copy()
365
+
366
+ # Shared state references for storing predictions and ground truths
367
+
368
+ true_vals = st.session_state.true_vals
369
+ pred_vals = st.session_state.pred_vals
370
+ true_timestamps = st.session_state.true_timestamps
371
+ pred_timestamps = st.session_state.pred_timestamps
372
+
373
+ # ============================== LightGBM Simulation ==============================
374
+
375
+ if model_choice == "LightGBM" and st.session_state.is_running:
376
+ model = load_cached_lightgbm_model()
377
+ st.write("Simulation started...")
378
+ st.markdown('<div id="simulation"></div>', unsafe_allow_html=True)
379
+
380
+ plot_title, plot_container, x_axis_label, info_container = init_simulation_layout()
381
+
382
+ for i in range(st.session_state.start_index, len(test_df)):
383
+ if not st.session_state.is_running:
384
+ break
385
+
386
+ current = test_df.iloc[i]
387
+ timestamp = current["date"]
388
+ features = current[FEATURES].values.reshape(1, -1)
389
+ prediction = model.predict(features)[0]
390
+
391
+ pred_vals.append(prediction)
392
+ pred_timestamps.append(timestamp)
393
+
394
+ if i >= 1:
395
+ prev_actual = test_df.iloc[i - 1]["consumption_MW"]
396
+ prev_time = test_df.iloc[i - 1]["date"]
397
+ true_vals.append(prev_actual)
398
+ true_timestamps.append(prev_time)
399
+
400
+ fig = create_prediction_plot(
401
+ pred_timestamps,
402
+ pred_vals,
403
+ true_timestamps,
404
+ true_vals,
405
+ window_hours,
406
+ y_min=test_df_filtered["consumption_MW"].min() - 2000,
407
+ y_max=test_df_filtered["consumption_MW"].max() + 2000,
408
+ )
409
+
410
+ render_simulation_view(
411
+ timestamp,
412
+ prediction,
413
+ prev_actual if i >= 1 else None,
414
+ i / len(test_df),
415
+ fig,
416
+ )
417
+
418
+ plt.close(fig) # free memory
419
+
420
+ st.session_state.start_index = i + 1
421
+ time.sleep(1 / (speed + 1e-9))
422
+
423
+ st.success("Simulation completed!")
424
+
425
+
426
+ # ============================== Transformer Simulation ==============================
427
+
428
+ spinner_placeholder = st.empty()
429
+
430
+ if model_choice == "Transformer Model (moments)":
431
+ if st.session_state.is_running:
432
+ st.write("Simulation started (Transformer)...")
433
+ st.markdown('<div id="simulation"></div>', unsafe_allow_html=True)
434
+
435
+ if not st.session_state.first_plot_shown:
436
+ spinner_placeholder.markdown("Running first prediction – please wait...")
437
+
438
+ plot_title, plot_container, x_axis_label, info_container = (
439
+ init_simulation_layout()
440
+ )
441
+
442
+ # Access model, dataset, and device
443
+ model, test_dataset, device = load_transformer_model_and_dataset()
444
+ data = test_dataset.data # already scaled
445
+ scaler = test_dataset.scaler
446
+ n_channels = test_dataset.n_channels
447
+
448
+ test_start_idx = (
449
+ len(InformerDataset(data_split="train", forecast_horizon=FORECAST_HORIZON))
450
+ + SEQ_LEN
451
+ )
452
+ base_timestamp = pd.read_csv(DATA_PATH, parse_dates=["date"])["date"].iloc[
453
+ test_start_idx
454
+ ] # get original timestamp for later, since it is no longer in the dataset
455
+
456
+ # Step 1: find the index at which the hour equals 00:00
457
+ offset = 0
458
+ while (base_timestamp + pd.Timedelta(hours=offset)).time() != pd.Timestamp(
459
+ "00:00:00"
460
+ ).time():
461
+ offset += 1
462
+
463
+ # New start index for the simulation
464
+ start_index = offset
465
+
466
+ # Initialize session state if needed
467
+ if "start_index" not in st.session_state or st.session_state.start_index == 0:
468
+ st.session_state.start_index = start_index
469
+
470
+ # Prepare: list of valid indices within the selected date range
471
+ valid_indices = []
472
+ for i in range(start_index, len(test_dataset)):
473
+ timestamp = base_timestamp + pd.Timedelta(hours=i)
474
+ if start_date <= timestamp.date() <= end_date:
475
+ valid_indices.append(i)
476
+
477
+ # Progress display
478
+ total_steps = len(valid_indices)
479
+
480
+ # Current position within the list (not the global dataset index!)
481
+ if "valid_pos" not in st.session_state:
482
+ st.session_state.valid_pos = 0
483
+
484
+ # Main loop: iterate only over the valid indices
485
+ for relative_idx, i in enumerate(valid_indices[st.session_state.valid_pos :]):
486
+
487
+ # for i in range(st.session_state.start_index, len(test_dataset)):
488
+ if not st.session_state.is_running:
489
+ break
490
+
491
+ current_pred = predict_transformer_step(model, test_dataset, i, device)
492
+ current_time = base_timestamp + pd.Timedelta(hours=i)
493
+
494
+ pred_vals.append(current_pred)
495
+ pred_timestamps.append(current_time)
496
+
497
+ if i >= 1:
498
+ prev_actual = test_dataset[i - 1][1][
499
+ 0, 0
500
+ ] # first forecast value of the previous row
501
+ # Rescale back
502
+ dummy_actual = np.zeros((1, n_channels))
503
+ dummy_actual[:, 0] = prev_actual
504
+ actual_val = scaler.inverse_transform(dummy_actual)[0, 0]
505
+
506
+ true_time = current_time - pd.Timedelta(hours=1)
507
+
508
+ if true_time >= pd.to_datetime(start_date):
509
+ true_vals.append(actual_val)
510
+ true_timestamps.append(true_time)
511
+
512
+ # Create plot
513
+ fig = create_prediction_plot(
514
+ pred_timestamps,
515
+ pred_vals,
516
+ true_timestamps,
517
+ true_vals,
518
+ window_hours,
519
+ y_min=test_df_filtered["consumption_MW"].min() - 2000,
520
+ y_max=test_df_filtered["consumption_MW"].max() + 2000,
521
+ )
522
+ if len(pred_vals) >= 2 and len(true_vals) >= 1:
523
+ render_simulation_view(
524
+ current_time,
525
+ current_pred,
526
+ actual_val if i >= 1 else None,
527
+ st.session_state.valid_pos / total_steps,
528
+ fig,
529
+ )
530
+ if not st.session_state.first_plot_shown:
531
+ spinner_placeholder.empty()
532
+ st.session_state.first_plot_shown = True
533
+
534
+ plt.close(fig) # free memory
535
+
536
+ st.session_state.valid_pos += 1
537
+ time.sleep(1 / (speed + 1e-9))
538
+
539
+ st.success("Simulation completed!")
540
+
541
+
542
+ # ============================== Scroll Sync ==============================
543
+
544
+ st.markdown(
545
+ """
546
+ <script>
547
+ window.addEventListener("message", (event) => {
548
+ if (event.data.type === "save_scroll") {
549
+ const pyScroll = event.data.scrollY;
550
+ window.parent.postMessage({type: "streamlit:setComponentValue", value: pyScroll}, "*");
551
+ }
552
+ });
553
+ </script>
554
+ """,
555
+ unsafe_allow_html=True,
556
+ )
streamlit_simulation/config_streamlit.py ADDED
@@ -0,0 +1,24 @@
1
+ # config_streamlit
2
+ import os
3
+
4
+ # Base directory → points to the project root
5
+ BASE_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
6
+
7
+ # Model paths
8
+ MODEL_PATH_LIGHTGBM = os.path.join(
9
+ BASE_DIR, "lightgbm_model", "model", "lightgbm_final_model.pkl"
10
+ )
11
+ MODEL_PATH_TRANSFORMER = os.path.join(
12
+ BASE_DIR, "transformer_model", "model", "checkpoints", "model_final.pth"
13
+ )
14
+
15
+ # Data path
16
+ DATA_PATH = os.path.join(
17
+ BASE_DIR, "data", "processed", "energy_consumption_aggregated_cleaned.csv"
18
+ )
19
+
20
+ # Color palette for Streamlit layout
21
+ PLOT_COLOR = "#edf1f7" # Plot background color
22
+
23
+ # Constants
24
+ TRAIN_RATIO = 0.7 # Train/test split ratio used by both models
streamlit_simulation/utils/dummy.py ADDED
@@ -0,0 +1,43 @@
1
+ # streamlit_simulation/utils/dummy.py
2
+ import numpy as np
3
+ import torch
4
+
5
+
6
+ class DummyDataset:
7
+ def __init__(self, length=100):
8
+ self.data = np.zeros((length, 10)) # dummy data
9
+ self.scaler = DummyScaler()
10
+ self.n_channels = 1
11
+ self.length = length
12
+
13
+ def __len__(self):
14
+ return self.length
15
+
16
+ def __getitem__(self, idx):
17
+ timeseries = np.zeros((48, 1)) # (SEQ_LEN, Channels)
18
+ target = np.zeros((1, 1)) # Forecast target
19
+ mask = np.ones((48,)) # dummy mask
20
+ return timeseries, target, mask
21
+
22
+
23
+ class DummyScaler:
24
+ def inverse_transform(self, x):
25
+ return x # no scaling needed
26
+
27
+
28
+ class DummyOutput:
29
+ def __init__(self, forecast_shape):
30
+ # return a real tensor, as expected from the actual model
31
+ self.forecast = torch.tensor(np.full(forecast_shape, 42.0), dtype=torch.float32)
32
+
33
+
34
+ class DummyTransformerModel:
35
+ def __call__(self, x_enc, input_mask):
36
+ batch_size, seq_len, channels = x_enc.shape
37
+ forecast_shape = (batch_size, 1, channels)
38
+ return DummyOutput(forecast_shape)
39
+
40
+
41
+ class DummyLightGBMModel:
42
+ def predict(self, X):
43
+ return np.zeros(len(X)) # returns an np.ndarray, like the real model
streamlit_simulation/utils/env.py ADDED
@@ -0,0 +1,9 @@
1
+ import os
2
+
3
+ from dotenv import load_dotenv
4
+
5
+ load_dotenv() # load once at import time
6
+
7
+
8
+ def use_dummy() -> bool:
9
+ return os.getenv("USE_DUMMY_MODEL", "false").lower() == "true"
streamlit_simulation/utils_streamlit.py ADDED
@@ -0,0 +1,29 @@
1
+ import os
2
+
3
+ import pandas as pd
4
+ from huggingface_hub import hf_hub_download
5
+
6
+ from streamlit_simulation.config_streamlit import DATA_PATH
7
+
8
+ HF_REPO = "dlaj/energy-forecasting-files"
9
+ HF_FILENAME = "data/processed/energy_consumption_aggregated_cleaned.csv"
10
+
11
+
12
+ def load_data():
13
+ # Check whether the local file exists
14
+ if not os.path.exists(DATA_PATH):
15
+ print(f"Local file not found: {DATA_PATH}")
16
+ print("Downloading from Hugging Face...")
17
+
18
+ # Download from the Hugging Face Hub
19
+ downloaded_path = hf_hub_download(
20
+ repo_id=HF_REPO,
21
+ filename=HF_FILENAME,
22
+ repo_type="dataset",
23
+ cache_dir="hf_cache", # optional: local cache directory
24
+ )
25
+
26
+ return pd.read_csv(downloaded_path, parse_dates=["date"])
27
+
28
+ print(f"Loading local file: {DATA_PATH}")
29
+ return pd.read_csv(DATA_PATH, parse_dates=["date"])
transformer_model/scripts/__init__.py ADDED
@@ -0,0 +1 @@
1
+ # __init__.py
transformer_model/scripts/config_transformer.py ADDED
@@ -0,0 +1,33 @@
1
+ # config.py
2
+ import os
3
+
4
+ # Base Directory
5
+ BASE_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
6
+
7
+ # Data paths
8
+ DATA_PATH = os.path.join(
9
+ BASE_DIR, "..", "data", "processed", "energy_consumption_aggregated_cleaned.csv"
10
+ )
11
+
12
+ # Other paths
13
+ CHECKPOINT_DIR = os.path.join(BASE_DIR, "model", "checkpoints")
14
+ RESULTS_DIR = os.path.join(BASE_DIR, "results")
15
+
16
+
17
+ # ========== Model Settings ==========
18
+ SEQ_LEN = 512 # Input sequence length (number of time steps the model sees)
19
+ FORECAST_HORIZON = 1 # Number of future steps the model should predict
20
+ HEAD_DROPOUT = 0.1 # Dropout in the head to prevent overfitting
21
+ WEIGHT_DECAY = 0.0 # L2 regularization (0 means off)
22
+
23
+ # ========== Training Settings ==========
24
+ MAX_EPOCHS = 9 # Optimal number of epochs based on performance curve
25
+ BATCH_SIZE = 32 # Batch size for training and evaluation
26
+ LEARNING_RATE = 1e-4 # Base learning rate
27
+ MAX_LR = 1e-4 # Max LR for OneCycleLR scheduler
28
+ GRAD_CLIP = 5.0 # Gradient clipping threshold
29
+
30
+ # ========== Freezing Strategy ==========
31
+ FREEZE_ENCODER = True
32
+ FREEZE_EMBEDDER = True
33
+ FREEZE_HEAD = False # unfreeze only the final forecasting head for fine-tuning
transformer_model/scripts/evaluation/__init__.py ADDED
@@ -0,0 +1 @@
1
+ # __init__
transformer_model/scripts/evaluation/evaluate.py ADDED
@@ -0,0 +1,144 @@
1
+ # evaluate.py
2
+
3
+ import json
4
+ import logging
5
+ import os
6
+
7
+ import numpy as np
8
+ import pandas as pd
9
+ import torch
10
+ from momentfm.utils.utils import control_randomness
11
+ from sklearn.metrics import mean_squared_error, r2_score
12
+ from tqdm import tqdm
13
+
14
+ from transformer_model.scripts.config_transformer import (DATA_PATH,
15
+ FORECAST_HORIZON,
16
+ RESULTS_DIR, SEQ_LEN)
17
+ from transformer_model.scripts.utils.check_device import check_device
18
+ from transformer_model.scripts.utils.informer_dataset_class import \
19
+ InformerDataset
20
+ from transformer_model.scripts.utils.load_final_model import \
21
+ load_final_transformer_model
22
+
23
+ # Setup logging
24
+ logging.basicConfig(
25
+ level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
26
+ )
27
+
28
+
29
+ def evaluate():
30
+ control_randomness(seed=13)
31
+ # Set device
32
+ device, backend, scaler = check_device()
33
+ logging.info(f"Evaluation is running on: {backend} ({device})")
34
+
35
+ # Load final model
36
+ model, _ = load_final_transformer_model(device)
37
+
38
+ # Recreate training dataset to get the fitted scaler
39
+ train_dataset = InformerDataset(
40
+ data_split="train", random_seed=13, forecast_horizon=FORECAST_HORIZON
41
+ )
42
+
43
+ # Use its scaler in the test dataset
44
+ test_dataset = InformerDataset(
45
+ data_split="test", random_seed=13, forecast_horizon=FORECAST_HORIZON
46
+ )
47
+
48
+ test_dataset.scaler = train_dataset.scaler
49
+
50
+ test_loader = torch.utils.data.DataLoader(
51
+ test_dataset, batch_size=32, shuffle=False
52
+ )
53
+
54
+ trues, preds = [], []
55
+
56
+ with torch.no_grad():
57
+ for timeseries, forecast, input_mask in tqdm(
58
+ test_loader, desc="Evaluating on test set"
59
+ ):
60
+ timeseries = timeseries.float().to(device)
61
+ forecast = forecast.float().to(device)
62
+ input_mask = input_mask.to(device) # important: mask must be on the same device
63
+
64
+ output = model(x_enc=timeseries, input_mask=input_mask)
65
+
66
+ trues.append(forecast.cpu().numpy())
67
+ preds.append(output.forecast.cpu().numpy())
68
+
69
+ trues = np.concatenate(trues, axis=0)
70
+ preds = np.concatenate(preds, axis=0)
71
+
72
+ # Extract only first feature (consumption)
73
+ true_values = trues[:, 0, :]
74
+ pred_values = preds[:, 0, :]
75
+
76
+ # Inverse normalization
77
+ n_features = test_dataset.n_channels
78
+ true_reshaped = np.column_stack(
79
+ [true_values.flatten()]
80
+ + [np.zeros_like(true_values.flatten())] * (n_features - 1)
81
+ )
82
+ pred_reshaped = np.column_stack(
83
+ [pred_values.flatten()]
84
+ + [np.zeros_like(pred_values.flatten())] * (n_features - 1)
85
+ )
86
+
87
+ true_original = test_dataset.scaler.inverse_transform(true_reshaped)[:, 0]
88
+ pred_original = test_dataset.scaler.inverse_transform(pred_reshaped)[:, 0]
89
+
90
+ # Build the timestamp index: the date column is dropped inside InformerDataset, so we reload the original CSV and use the index of the start of the test data to recover the dates
91
+ csv_path = os.path.join(DATA_PATH)
92
+ df = pd.read_csv(csv_path, parse_dates=["date"])
93
+
94
+ train_len = len(train_dataset)
95
+ test_start_idx = train_len + SEQ_LEN
96
+ start_timestamp = df["date"].iloc[test_start_idx]
97
+ logging.info(f"[DEBUG] timestamp: {start_timestamp}")
98
+
99
+ timestamps = [
100
+ start_timestamp + pd.Timedelta(hours=i) for i in range(len(true_original))
101
+ ]
102
+
103
+ df = pd.DataFrame(
104
+ {
105
+ "Timestamp": timestamps,
106
+ "True Consumption (MW)": true_original,
107
+ "Predicted Consumption (MW)": pred_original,
108
+ }
109
+ )
110
+
111
+ # Save results to CSV
112
+ os.makedirs(RESULTS_DIR, exist_ok=True)
113
+ results_path = os.path.join(RESULTS_DIR, "test_results.csv")
114
+ df.to_csv(results_path, index=False)
115
+ logging.info(f"Saved prediction results to: {results_path}")
116
+
117
+ # Evaluation metrics
118
+ mse = mean_squared_error(
119
+ df["True Consumption (MW)"], df["Predicted Consumption (MW)"]
120
+ )
121
+ rmse = np.sqrt(mse)
122
+ mape = (
123
+ np.mean(
124
+ np.abs(
125
+ (df["True Consumption (MW)"] - df["Predicted Consumption (MW)"])
126
+ / df["True Consumption (MW)"]
127
+ )
128
+ )
129
+ * 100
130
+ )
131
+ r2 = r2_score(df["True Consumption (MW)"], df["Predicted Consumption (MW)"])
132
+
133
+ # Save metrics to JSON
134
+ metrics = {"RMSE": float(rmse), "MAPE": float(mape), "R2": float(r2)}
135
+ metrics_path = os.path.join(RESULTS_DIR, "evaluation_metrics.json")
136
+ with open(metrics_path, "w") as f:
137
+ json.dump(metrics, f)
138
+
139
+ logging.info(f"Saved evaluation metrics to: {metrics_path}")
140
+ logging.info(f"RMSE: {rmse:.3f} | MAPE: {mape:.2f}% | R²: {r2:.3f}")
141
+
142
+
143
+ if __name__ == "__main__":
144
+ evaluate()
transformer_model/scripts/evaluation/plot_metrics.py ADDED
@@ -0,0 +1,106 @@
+ # plot_metrics.py
+
+ import json
+ import os
+
+ import matplotlib.pyplot as plt
+ import pandas as pd
+
+ from transformer_model.scripts.config_transformer import RESULTS_DIR
+
+ # === Plot 1: Training Metrics ===
+
+ # Load training metrics
+ training_metrics_path = os.path.join(RESULTS_DIR, "training_metrics.json")
+ with open(training_metrics_path, "r") as f:
+     metrics = json.load(f)
+
+ train_losses = metrics["train_losses"]
+ test_mses = metrics["test_mses"]
+ test_maes = metrics["test_maes"]
+
+ plt.figure(figsize=(10, 6))
+ plt.plot(
+     range(1, len(train_losses) + 1), train_losses, label="Train Loss", color="blue"
+ )
+ plt.plot(range(1, len(test_mses) + 1), test_mses, label="Test MSE", color="red")
+ plt.plot(range(1, len(test_maes) + 1), test_maes, label="Test MAE", color="green")
+ plt.xlabel("Epoch")
+ plt.ylabel("Loss / Metric")
+ plt.title("Training Loss vs Test Metrics")
+ plt.legend()
+ plt.grid(True)
+
+ plot_path = os.path.join(RESULTS_DIR, "training_plot.png")
+ plt.savefig(plot_path)
+ print(f"[Saved] Training metrics plot: {plot_path}")
+ plt.show()
+
+
+ # === Plot 2: Predictions vs Ground Truth (Full Range) ===
+
+ # Load comparison results
+ comparison_path = os.path.join(RESULTS_DIR, "test_results.csv")
+ df_comparison = pd.read_csv(comparison_path, parse_dates=["Timestamp"])
+
+ plt.figure(figsize=(15, 6))
+ plt.plot(
+     df_comparison["Timestamp"],
+     df_comparison["True Consumption (MW)"],
+     label="True",
+     color="darkblue",
+ )
+ plt.plot(
+     df_comparison["Timestamp"],
+     df_comparison["Predicted Consumption (MW)"],
+     label="Predicted",
+     color="red",
+     linestyle="--",
+ )
+ plt.title("Energy Consumption: Predictions vs Ground Truth")
+ plt.xlabel("Time")
+ plt.ylabel("Consumption (MW)")
+ plt.legend()
+ plt.grid(True)
+ plt.tight_layout()
+
+ plot_path = os.path.join(RESULTS_DIR, "comparison_plot_full.png")
+ plt.savefig(plot_path)
+ print(f"[Saved] Full range comparison plot: {plot_path}")
+ plt.show()
+
+
+ # === Plot 3: Predictions vs Ground Truth (First Month) ===
+
+ first_month_start = df_comparison["Timestamp"].min()
+ first_month_end = first_month_start + pd.Timedelta(days=25)
+ df_first_month = df_comparison[
+     (df_comparison["Timestamp"] >= first_month_start)
+     & (df_comparison["Timestamp"] <= first_month_end)
+ ]
+
+ plt.figure(figsize=(15, 6))
+ plt.plot(
+     df_first_month["Timestamp"],
+     df_first_month["True Consumption (MW)"],
+     label="True",
+     color="darkblue",
+ )
+ plt.plot(
+     df_first_month["Timestamp"],
+     df_first_month["Predicted Consumption (MW)"],
+     label="Predicted",
+     color="red",
+     linestyle="--",
+ )
+ plt.title("Energy Consumption (First Month): Predictions vs Ground Truth")
+ plt.xlabel("Time")
+ plt.ylabel("Consumption (MW)")
+ plt.legend()
+ plt.grid(True)
+ plt.tight_layout()
+
+ plot_path = os.path.join(RESULTS_DIR, "comparison_plot_1month.png")
+ plt.savefig(plot_path)
+ print(f"[Saved] 1-Month comparison plot: {plot_path}")
+ plt.show()
transformer_model/scripts/training/__init__.py ADDED
@@ -0,0 +1 @@
+ # __init__
transformer_model/scripts/training/load_basis_model.py ADDED
@@ -0,0 +1,69 @@
+ # load_basis_model.py
+ # Load and initialize the base MOMENT model before finetuning
+
+ import logging
+
+ import torch
+ from momentfm import MOMENTPipeline
+
+ from transformer_model.scripts.config_transformer import (FORECAST_HORIZON,
+                                                            FREEZE_EMBEDDER,
+                                                            FREEZE_ENCODER,
+                                                            FREEZE_HEAD,
+                                                            HEAD_DROPOUT,
+                                                            SEQ_LEN,
+                                                            WEIGHT_DECAY)
+
+ # Setup logging
+ logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")
+
+
+ def load_moment_model():
+     """
+     Loads and configures the MOMENT model for forecasting.
+     """
+     logging.info("Loading MOMENT model...")
+     model = MOMENTPipeline.from_pretrained(
+         "AutonLab/MOMENT-1-large",
+         model_kwargs={
+             "task_name": "forecasting",
+             "forecast_horizon": FORECAST_HORIZON,  # default = 1
+             "head_dropout": HEAD_DROPOUT,  # default = 0.1
+             "weight_decay": WEIGHT_DECAY,  # default = 0.0
+             "freeze_encoder": FREEZE_ENCODER,  # default = True
+             "freeze_embedder": FREEZE_EMBEDDER,  # default = True
+             "freeze_head": FREEZE_HEAD,  # default = False
+         },
+     )
+
+     model.init()
+     logging.info("Model initialized successfully.")
+     return model
+
+
+ def print_trainable_params(model):
+     """
+     Logs all trainable (unfrozen) parameters of the model.
+     """
+     logging.info("Unfrozen parameters:")
+     for name, param in model.named_parameters():
+         if param.requires_grad:
+             logging.info(f" {name}")
+
+
+ def test_dummy_forward(model):
+     """
+     Performs a dummy forward pass to verify the model runs without error.
+     """
+     logging.info(
+         "Running dummy forward pass with random tensors to see if model is running."
+     )
+     dummy_x = torch.randn(16, 1, SEQ_LEN)
+     output = model(x_enc=dummy_x)
+     logging.info(f"Dummy forward pass successful. Output shape: {output.shape}")
+
+
+ if __name__ == "__main__":
+     model = load_moment_model()
+     print_trainable_params(model)
+     test_dummy_forward(model)
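
> **Note:** As a quick sanity check of the freezing flags above, a small sketch (assuming `momentfm` and the project config are importable) that reports how many parameters remain trainable after the encoder and embedder are frozen.

```python
# Sketch: count trainable vs. total parameters of the frozen base model.
from transformer_model.scripts.training.load_basis_model import load_moment_model

model = load_moment_model()
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total = sum(p.numel() for p in model.parameters())
print(f"Trainable parameters: {trainable:,} of {total:,} ({100 * trainable / total:.2f}%)")
```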
transformer_model/scripts/training/train.py ADDED
@@ -0,0 +1,199 @@
+ # train.py
+
+ import json
+ import logging
+ import os
+ import time
+
+ import numpy as np
+ import torch
+ from momentfm.utils.utils import control_randomness
+ from sklearn.metrics import mean_absolute_error, mean_squared_error
+ from tqdm import tqdm
+
+ from transformer_model.scripts.config_transformer import (CHECKPOINT_DIR,
+                                                            GRAD_CLIP,
+                                                            LEARNING_RATE,
+                                                            MAX_EPOCHS, MAX_LR,
+                                                            RESULTS_DIR)
+ from transformer_model.scripts.training.load_basis_model import \
+     load_moment_model
+ from transformer_model.scripts.utils.check_device import check_device
+ from transformer_model.scripts.utils.create_dataloaders import \
+     create_dataloaders
+
+ # === Setup logging ===
+ logging.basicConfig(
+     level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
+ )
+
+
+ def train():
+     # Start timing
+     start_time = time.time()
+
+     # Setup device (CUDA / DirectML / CPU) and AMP scaler
+     device, backend, scaler = check_device()
+
+     # Load base model
+     model = load_moment_model().to(device)
+
+     # Set random seeds for reproducibility
+     control_randomness(seed=13)
+
+     # Setup loss function and optimizer
+     criterion = torch.nn.MSELoss().to(device)
+     optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
+
+     # Load data
+     train_loader, test_loader = create_dataloaders()
+
+     # Setup learning rate scheduler (OneCycle policy)
+     total_steps = len(train_loader) * MAX_EPOCHS
+     scheduler = torch.optim.lr_scheduler.OneCycleLR(
+         optimizer, max_lr=MAX_LR, total_steps=total_steps, pct_start=0.3
+     )
+
+     # Ensure output folders exist
+     os.makedirs(CHECKPOINT_DIR, exist_ok=True)
+     os.makedirs(RESULTS_DIR, exist_ok=True)
+
+     # Store metrics
+     train_losses, test_mses, test_maes = [], [], []
+
+     best_mae = float("inf")
+     best_epoch = None
+     no_improve_epochs = 0
+     patience = 5
+
+     for epoch in range(MAX_EPOCHS):
+         model.train()
+         epoch_losses = []
+
+         for timeseries, forecast, input_mask in tqdm(
+             train_loader, desc=f"Epoch {epoch}"
+         ):
+             timeseries = timeseries.float().to(device)
+             input_mask = input_mask.to(device)
+             forecast = forecast.float().to(device)
+
+             # Zero gradients
+             optimizer.zero_grad(set_to_none=True)
+
+             # Forward pass (with AMP if enabled)
+             if scaler:
+                 with torch.amp.autocast(device_type="cuda"):
+                     output = model(x_enc=timeseries, input_mask=input_mask)
+                     loss = criterion(output.forecast, forecast)
+             else:
+                 output = model(x_enc=timeseries, input_mask=input_mask)
+                 loss = criterion(output.forecast, forecast)
+
+             # Backward pass + optimization
+             if scaler:
+                 scaler.scale(loss).backward()
+                 scaler.unscale_(optimizer)
+                 torch.nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP)
+                 scaler.step(optimizer)
+                 scaler.update()
+             else:
+                 loss.backward()
+                 torch.nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP)
+                 optimizer.step()
+
+             epoch_losses.append(loss.item())
+
+         average_train_loss = np.mean(epoch_losses)
+         train_losses.append(average_train_loss)
+         logging.info(f"Epoch {epoch}: Train Loss = {average_train_loss:.4f}")
+
+         # === Evaluation ===
+         model.eval()
+         trues, preds = [], []
+
+         with torch.no_grad():
+             for timeseries, forecast, input_mask in test_loader:
+                 timeseries = timeseries.float().to(device)
+                 input_mask = input_mask.to(device)
+                 forecast = forecast.float().to(device)
+
+                 if scaler:
+                     with torch.amp.autocast(device_type="cuda"):
+                         output = model(x_enc=timeseries, input_mask=input_mask)
+                 else:
+                     output = model(x_enc=timeseries, input_mask=input_mask)
+
+                 trues.append(forecast.detach().cpu().numpy())
+                 preds.append(output.forecast.detach().cpu().numpy())
+
+         trues = np.concatenate(trues, axis=0)
+         preds = np.concatenate(preds, axis=0)
+
+         # Reshape for sklearn metrics
+         trues_2d = trues.reshape(trues.shape[0], -1)
+         preds_2d = preds.reshape(preds.shape[0], -1)
+
+         mse = mean_squared_error(trues_2d, preds_2d)
+         mae = mean_absolute_error(trues_2d, preds_2d)
+
+         test_mses.append(mse)
+         test_maes.append(mae)
+         logging.info(f"Epoch {epoch}: Test MSE = {mse:.4f}, MAE = {mae:.4f}")
+
+         # === Early Stopping Check ===
+         if mae < best_mae:
+             best_mae = mae
+             best_epoch = epoch
+             no_improve_epochs = 0
+
+             # Save best model
+             best_model_path = os.path.join(CHECKPOINT_DIR, "best_model.pth")
+             torch.save(model.state_dict(), best_model_path)
+             logging.info(
+                 f"New best model saved to: {best_model_path} (MAE: {best_mae:.4f})"
+             )
+         else:
+             no_improve_epochs += 1
+             logging.info(f"No improvement in MAE for {no_improve_epochs} epoch(s).")
+
+         if no_improve_epochs >= patience:
+             logging.info("Early stopping triggered.")
+             break
+
+         # Save checkpoint
+         checkpoint_path = os.path.join(CHECKPOINT_DIR, f"model_epoch_{epoch}.pth")
+         torch.save(model.state_dict(), checkpoint_path)
+
+         scheduler.step()
+
+     logging.info(f"Best model was at epoch {best_epoch} with MAE: {best_mae:.4f}")
+
+     # Save final model
+     final_model_path = os.path.join(CHECKPOINT_DIR, "model_final.pth")
+     torch.save(model.state_dict(), final_model_path)
+     logging.info(f"Final model saved to: {final_model_path}")
+     logging.info(f"Final Test MSE: {test_mses[-1]:.4f}, MAE: {test_maes[-1]:.4f}")
+
+     # Save training metrics
+     metrics = {
+         "train_losses": [float(x) for x in train_losses],
+         "test_mses": [float(x) for x in test_mses],
+         "test_maes": [float(x) for x in test_maes],
+     }
+
+     metrics_path = os.path.join(RESULTS_DIR, "training_metrics.json")
+     with open(metrics_path, "w") as f:
+         json.dump(metrics, f)
+     logging.info(f"Training metrics saved to: {metrics_path}")
+
+     # Done
+     elapsed = time.time() - start_time
+     logging.info(f"Training complete in {elapsed / 60:.2f} minutes.")
+
+
+ # === Entry Point ===
+ if __name__ == "__main__":
+     try:
+         train()
+     except Exception as e:
+         logging.error(f"Training failed: {e}")
transformer_model/scripts/utils/__init__.py ADDED
@@ -0,0 +1 @@
+ # __init__
transformer_model/scripts/utils/check_device.py ADDED
@@ -0,0 +1,55 @@
+ import importlib
+ import subprocess
+ import sys
+
+ import torch
+
+
+ def install_package(package_name):
+     subprocess.check_call([sys.executable, "-m", "pip", "install", package_name])
+
+
+ def check_device():
+     # **Check for NVIDIA GPU (CUDA)**
+     if torch.cuda.is_available():
+         device = torch.device("cuda")  # Use NVIDIA GPU
+         backend = "CUDA (NVIDIA)"
+         mixed_precision = True  # Use Automatic Mixed Precision (AMP)
+
+     # **If no NVIDIA GPU, check for AMD GPU (DirectML) only in Windows**
+     else:
+         try:
+             # Only try DirectML if the environment is Windows and DirectML is installed
+             if "win32" in sys.platform:
+                 torch_directml = importlib.import_module("torch_directml")
+                 if torch_directml.device_count() > 0:
+                     device = torch_directml.device()  # Use AMD GPU with DirectML
+                     backend = "DirectML (AMD)"
+                     mixed_precision = False  # No AMP for AMD GPU
+                 else:
+                     raise ImportError  # AMD GPU not found
+             else:
+                 device = torch.device("cpu")
+                 backend = "CPU"
+                 mixed_precision = False  # No AMP for CPU
+
+         except ImportError:
+             # If DirectML is not installed or AMD GPU not found
+             device = torch.device("cpu")
+             backend = "CPU"
+             mixed_precision = False  # No AMP for CPU
+
+     # Print the chosen device info
+     print(f"Training is running on: {backend} ({device})")
+
+     # **Initialize scaler (only for NVIDIA)**
+     if mixed_precision:
+         scaler = torch.amp.GradScaler()
+     else:
+         scaler = None  # No scaler needed for AMD/CPU
+
+     return device, backend, scaler
+
+
+ if __name__ == "__main__":
+     device, backend, scaler = check_device()
transformer_model/scripts/utils/create_dataloaders.py ADDED
@@ -0,0 +1,46 @@
+ # create_dataloaders.py
+
+ import logging
+
+ from momentfm.utils.utils import control_randomness
+ from torch.utils.data import DataLoader
+
+ from transformer_model.scripts.config_transformer import (BATCH_SIZE,
+                                                            FORECAST_HORIZON)
+ from transformer_model.scripts.utils.informer_dataset_class import \
+     InformerDataset
+
+
+ def create_dataloaders():
+     logging.info("Setting random seeds...")
+     control_randomness(seed=13)
+
+     logging.info("Loading training dataset...")
+     train_dataset = InformerDataset(
+         data_split="train", random_seed=13, forecast_horizon=FORECAST_HORIZON
+     )
+     logging.info(
+         "Train set loaded — Samples: %d | Features: %d",
+         len(train_dataset),
+         train_dataset.n_channels,
+     )
+
+     logging.info("Loading test dataset...")
+     test_dataset = InformerDataset(
+         data_split="test", random_seed=13, forecast_horizon=FORECAST_HORIZON
+     )
+     logging.info(
+         "Test set loaded — Samples: %d | Features: %d",
+         len(test_dataset),
+         test_dataset.n_channels,
+     )
+
+     train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
+     test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=True)
+
+     logging.info("Dataloaders created successfully.")
+     return train_loader, test_loader
+
+
+ if __name__ == "__main__":
+     create_dataloaders()
transformer_model/scripts/utils/informer_dataset_class.py ADDED
@@ -0,0 +1,123 @@
+ # informer_dataset.py
+
+ import logging
+ from typing import Optional
+
+ import numpy as np
+ import pandas as pd
+ from sklearn.preprocessing import StandardScaler
+
+ from transformer_model.scripts.config_transformer import DATA_PATH, SEQ_LEN
+
+ logging.basicConfig(level=logging.INFO)
+
+
+ class InformerDataset:
+     def __init__(
+         self,
+         forecast_horizon: Optional[int],
+         data_split: str = "train",
+         data_stride_len: int = 1,
+         task_name: str = "forecasting",
+         random_seed: int = 42,
+     ):
+         """
+         Parameters
+         ----------
+         forecast_horizon : int
+             Length of the prediction sequence.
+         data_split : str
+             'train' or 'test'.
+         data_stride_len : int
+             Stride length between time windows.
+         task_name : str
+             'forecasting' or 'imputation'.
+         random_seed : int
+             For reproducibility.
+         """
+
+         self.seq_len = SEQ_LEN
+         self.forecast_horizon = forecast_horizon
+         self.full_file_path_and_name = DATA_PATH
+         self.data_split = data_split
+         self.data_stride_len = data_stride_len
+         self.task_name = task_name
+         self.random_seed = random_seed
+
+         self._read_data()
+
+     def _get_borders(self):
+         train_ratio = 0.7
+         n_train = int(self.length_timeseries_original * train_ratio)
+         n_test = self.length_timeseries_original - n_train
+
+         train_end = n_train
+         test_start = train_end - self.seq_len
+         test_end = test_start + n_test + self.seq_len
+
+         # logging.info(f"Train range: 0 to {train_end}")
+         # logging.info(f"Test range: {test_start} to {test_end}")
+
+         return slice(0, train_end), slice(test_start, test_end)
+
+     def _read_data(self):
+         self.scaler = StandardScaler()
+
+         df = pd.read_csv(self.full_file_path_and_name)
+         self.length_timeseries_original = df.shape[0]
+         self.n_channels = df.shape[1] - 1  # exclude timestamp column
+
+         df.drop(columns=["date"], inplace=True)
+         df = df.infer_objects(copy=False).interpolate(method="cubic")
+
+         data_splits = self._get_borders()
+         train_data = df[data_splits[0]]
+
+         self.scaler.fit(train_data.values)
+         df = self.scaler.transform(df.values)
+
+         if self.data_split == "train":
+             self.data = df[data_splits[0], :]
+         elif self.data_split == "test":
+             self.data = df[data_splits[1], :]
+
+         self.length_timeseries = self.data.shape[0]
+
+         # logging.info(f"{self.data_split.capitalize()} set loaded.")
+         # logging.info(f"Time series length: {self.length_timeseries}")
+         # logging.info(f"Number of features: {self.n_channels}")
+
+     def __getitem__(self, index):
+         seq_start = self.data_stride_len * index
+         seq_end = seq_start + self.seq_len
+         input_mask = np.ones(self.seq_len)
+
+         if self.task_name == "forecasting":
+             pred_end = seq_end + self.forecast_horizon
+
+             if pred_end > self.length_timeseries:
+                 pred_end = self.length_timeseries
+                 seq_end = seq_end - self.forecast_horizon
+                 seq_start = seq_end - self.seq_len
+
+             timeseries = self.data[seq_start:seq_end, :].T
+             forecast = self.data[seq_end:pred_end, :].T
+
+             return timeseries, forecast, input_mask
+
+         elif self.task_name == "imputation":
+             if seq_end > self.length_timeseries:
+                 seq_end = self.length_timeseries
+                 seq_end = seq_end - self.seq_len
+
+             timeseries = self.data[seq_start:seq_end, :].T
+
+             return timeseries, input_mask
+
+     def __len__(self):
+         if self.task_name == "imputation":
+             return (self.length_timeseries - self.seq_len) // self.data_stride_len + 1
+         elif self.task_name == "forecasting":
+             return (
+                 self.length_timeseries - self.seq_len - self.forecast_horizon
+             ) // self.data_stride_len + 1
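
> **Note:** A minimal sketch of what `__getitem__` yields in forecasting mode (assuming the preprocessed CSV referenced by `DATA_PATH` is available locally): a channels-first input window of length `SEQ_LEN`, a forecast window of length `forecast_horizon`, and the input mask.

```python
# Sketch: inspect one sample from the dataset.
# Expected shapes: (n_channels, SEQ_LEN), (n_channels, FORECAST_HORIZON), (SEQ_LEN,).
from transformer_model.scripts.config_transformer import FORECAST_HORIZON
from transformer_model.scripts.utils.informer_dataset_class import InformerDataset

ds = InformerDataset(data_split="train", forecast_horizon=FORECAST_HORIZON)
timeseries, forecast, input_mask = ds[0]
print(timeseries.shape, forecast.shape, input_mask.shape)
```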
transformer_model/scripts/utils/load_final_model.py ADDED
@@ -0,0 +1,39 @@
+ import logging
+ import os
+
+ import torch
+ from huggingface_hub import hf_hub_download
+
+ from transformer_model.scripts.config_transformer import CHECKPOINT_DIR
+ from transformer_model.scripts.training.load_basis_model import \
+     load_moment_model
+
+ logging.basicConfig(level=logging.INFO)
+
+
+ # load model from checkpoint if available, else download it from hugging face
+ def load_real_transformer_model(device=None):  # ⬅️ name changed
+     if device is None:
+         device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+     model = load_moment_model()
+     filename = "model_final.pth"
+     local_path = os.path.join(CHECKPOINT_DIR, filename)
+
+     if os.path.exists(local_path):
+         checkpoint_path = local_path
+         print("Loading model from local path...")
+     else:
+         print("Downloading model from Hugging Face Hub...")
+         checkpoint_path = hf_hub_download(
+             repo_id="dlaj/energy-forecasting-files",  # adjust if needed
+             filename=f"transformer_model/{filename}",
+             repo_type="dataset",
+         )
+
+     model.load_state_dict(torch.load(checkpoint_path, map_location=device))
+     model.to(device)
+     model.eval()
+     logging.info(f"Model loaded from: {checkpoint_path}")
+
+     return model, device
transformer_model/scripts/utils/model_loader_wrapper.py ADDED
@@ -0,0 +1,42 @@
+ from streamlit_simulation.utils.env import use_dummy
+ from transformer_model.scripts.config_transformer import FORECAST_HORIZON
+ from transformer_model.scripts.utils.informer_dataset_class import \
+     InformerDataset
+ from transformer_model.scripts.utils.load_final_model import \
+     load_real_transformer_model
+
+ try:
+     from streamlit_simulation.utils.dummy import (DummyDataset,
+                                                   DummyTransformerModel)
+ except ImportError:
+     DummyTransformerModel = None
+     DummyDataset = None
+
+
+ def load_final_transformer_model():
+     if use_dummy():
+         if DummyTransformerModel is None:
+             raise ImportError("DummyTransformerModel not available")
+         return DummyTransformerModel(), "cpu"
+     else:
+         return load_real_transformer_model()
+
+
+ def load_model_and_dataset():
+     model, device = load_final_transformer_model()
+
+     if use_dummy():
+         if DummyDataset is None:
+             raise ImportError("DummyDataset not available")
+         dataset = DummyDataset(length=200)
+     else:
+         train_dataset = InformerDataset(
+             data_split="train", random_seed=13, forecast_horizon=FORECAST_HORIZON
+         )
+         test_dataset = InformerDataset(
+             data_split="test", random_seed=13, forecast_horizon=FORECAST_HORIZON
+         )
+         test_dataset.scaler = train_dataset.scaler
+         dataset = test_dataset
+
+     return model, dataset, device
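
> **Note:** A sketch of how the Streamlit simulation would consume this wrapper; that the dashboard calls it exactly this way is an assumption based on the function names.

```python
# Sketch: one call returns model, dataset and device (dummy or real, depending on env).
from transformer_model.scripts.utils.model_loader_wrapper import load_model_and_dataset

model, dataset, device = load_model_and_dataset()
print(type(model).__name__, len(dataset), device)
```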