3v324v23 committed on
Commit 6baf912 · 1 Parent(s): c7d76b0
.gitattributes DELETED
@@ -1 +0,0 @@
1
- lightgbm_model/model/lightgbm_final_model.pkl filter=lfs diff=lfs merge=lfs -text
 
 
.streamlit/config.toml DELETED
@@ -1,9 +0,0 @@
1
- [theme]
2
- base="light"
3
- primaryColor="#FF4B4B"
4
- backgroundColor="#f8f9fa"
5
- textColor="#004080"
6
- secondaryBackgroundColor="#edf1f7"
7
- font="sans serif"
8
-
9
-
 
README.md DELETED
@@ -1,12 +0,0 @@
1
- ---
2
- title: Energy Forecasting Demo
3
- emoji: ⚡
4
- colorFrom: blue
5
- colorTo: green
6
- sdk: streamlit
7
- sdk_version: 1.30.0
8
- app_file: streamlit_simulation/app.py
9
- pinned: true
10
- license: apache-2.0
11
- short_description: Hourly energy consumption forecasting
12
- ---
 
lightgbm_model/model/lightgbm_final_model.pkl DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:52777b05bde0cc4665aac0d18993701769c84edaf0ffe9cb3b82049fd779b56d
3
- size 1534227
 
lightgbm_model/scripts/__init__.py DELETED
@@ -1 +0,0 @@
1
- # __init__.py
 
 
lightgbm_model/scripts/config_lightgbm.py DELETED
@@ -1,41 +0,0 @@
1
- # config_lightgbm.py
2
- import os
3
-
4
- # === Paths ===
5
- BASE_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
6
- DATA_PATH = os.path.join(
7
- BASE_DIR, "..", "data", "processed", "energy_consumption_aggregated_cleaned.csv"
8
- )
9
- RESULTS_DIR = os.path.join(BASE_DIR, "results")
10
- MODEL_DIR = os.path.join(BASE_DIR, "model")
11
-
12
- # === Feature definition ===
13
- FEATURES = [
14
- "hour_sin",
15
- "hour_cos",
16
- "weekday_sin",
17
- "weekday_cos",
18
- "rolling_mean_6h",
19
- "month_sin",
20
- "month_cos",
21
- "temperature_c",
22
- "consumption_last_week",
23
- "consumption_yesterday",
24
- "consumption_last_hour",
25
- ]
26
- TARGET = "consumption_MW"
27
-
28
- # === Hyperparameters for LightGBM ===
29
- LIGHTGBM_PARAMS = {
30
- "learning_rate": 0.05,
31
- "num_leaves": 15,
32
- "max_depth": 5,
33
- "lambda_l1": 1.0,
34
- "lambda_l2": 0.0,
35
- "min_split_gain": 0.0,
36
- "n_estimators": 1000,
37
- "objective": "regression",
38
- }
39
-
40
- # === Early Stopping ===
41
- EARLY_STOPPING_ROUNDS = 50
 
lightgbm_model/scripts/eval/eval_lightgbm.py DELETED
@@ -1,156 +0,0 @@
1
- # eval_lightgbm.py
2
-
3
- import json
4
- import os
5
- import pickle
6
-
7
- import matplotlib.pyplot as plt
8
- import numpy as np
9
- import pandas as pd
10
- from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
11
-
12
- from lightgbm_model.scripts.config_lightgbm import DATA_PATH, RESULTS_DIR
13
- from lightgbm_model.scripts.utils import load_lightgbm_model
14
-
15
- # === Prepare results folder ===
16
- os.makedirs(RESULTS_DIR, exist_ok=True)
17
-
18
- # === Load model and eval_result ===
19
- # Load model
20
- model = load_lightgbm_model()
21
-
22
- # Load eval results
23
- with open(os.path.join(RESULTS_DIR, "lightgbm_eval_result.pkl"), "rb") as f:
24
- eval_result = pickle.load(f)
25
- X_train = pd.read_csv(os.path.join(RESULTS_DIR, "X_train.csv"))
26
- X_test = pd.read_csv(os.path.join(RESULTS_DIR, "X_test.csv"))
27
- y_test = pd.read_csv(os.path.join(RESULTS_DIR, "y_test.csv"))
28
-
29
- # === Learning curve ===
30
- train_rmse = eval_result["training"]["rmse"]
31
- valid_rmse = eval_result["valid_1"]["rmse"]
32
-
33
- plt.figure(figsize=(10, 5))
34
- plt.plot(train_rmse, label="Train RMSE")
35
- plt.plot(valid_rmse, label="Valid RMSE")
36
- plt.axvline(model.best_iteration_, color="gray", linestyle="--", label="Best Iteration")
37
- plt.xlabel("Boosting Round")
38
- plt.ylabel("RMSE")
39
- plt.title("LightGBM Learning Curve")
40
- plt.legend()
41
- plt.tight_layout()
42
- plt.savefig(os.path.join(RESULTS_DIR, "lightgbm_learning_curve.png"))
43
- # plt.show()
44
-
45
- # === Compute metrics ===
46
- y_pred = model.predict(X_test)
47
- mae = mean_absolute_error(y_test, y_pred)
48
- rmse = np.sqrt(mean_squared_error(y_test, y_pred))
49
- mape = (
50
- np.mean(
51
- np.abs(
52
- (y_test.values.flatten() - y_pred)
53
- / np.where(y_test.values.flatten() == 0, 1e-10, y_test.values.flatten())
54
- )
55
- )
56
- * 100
57
- )
58
- r2 = r2_score(y_test, y_pred)
59
-
60
- print(f"Test MAPE: {mape:.5f} %")
61
- print(f"Test MAE: {mae:.5f}")
62
- print(f"Test RMSE: {rmse:.5f}")
63
- print(f"Test R2: {r2:.5f}")
64
-
65
- metrics = {
66
- "model": "LightGBM",
67
- "MAE": round(mae, 2),
68
- "RMSE": round(rmse, 2),
69
- "MAPE (%)": round(mape, 2),
70
- "R2": round(r2, 4),
71
- "unit": "MW",
72
- }
73
-
74
- # Set output path
75
- output_path = os.path.join(RESULTS_DIR, "evaluation_metrics_lightgbm.json")
76
- # Save
77
- with open(output_path, "w") as f:
78
- json.dump(metrics, f, indent=4)
79
-
80
- print(f"Metriken gespeichert unter {output_path}")
81
-
82
- # === Feature Importance ===
83
- feature_importance = pd.DataFrame(
84
- {"Feature": X_train.columns, "Importance": model.feature_importances_}
85
- ).sort_values(by="Importance", ascending=False)
86
-
87
- plt.figure(figsize=(10, 6))
88
- plt.barh(feature_importance["Feature"], feature_importance["Importance"])
89
- plt.xlabel("Feature Importance")
90
- plt.title("LightGBM Feature Importance")
91
- plt.gca().invert_yaxis()
92
- plt.tight_layout()
93
- plt.savefig(os.path.join(RESULTS_DIR, "lightgbm_feature_importance.png"))
94
- # plt.show()
95
-
96
- # === Comparison plots ===
97
- results_df = pd.DataFrame(
98
- {
99
- "True Consumption (MW)": y_test.values.flatten(),
100
- "Predicted Consumption (MW)": y_pred,
101
- }
102
- )
103
-
104
- # Attach timestamps
105
- full_df = pd.read_csv(DATA_PATH)
106
- test_dates = full_df.iloc[int(len(full_df) * 0.8) :]["date"].reset_index(drop=True)
107
- results_df["Timestamp"] = pd.to_datetime(test_dates)
108
-
109
- # Full plot
110
- plt.figure(figsize=(15, 6))
111
- plt.plot(
112
- results_df["Timestamp"],
113
- results_df["True Consumption (MW)"],
114
- label="True",
115
- color="darkblue",
116
- )
117
- plt.plot(
118
- results_df["Timestamp"],
119
- results_df["Predicted Consumption (MW)"],
120
- label="Predicted",
121
- color="red",
122
- linestyle="--",
123
- )
124
- plt.title("Predicted vs True Consumption")
125
- plt.xlabel("Timestamp")
126
- plt.ylabel("Consumption (MW)")
127
- plt.legend()
128
- plt.tight_layout()
129
- plt.savefig(os.path.join(RESULTS_DIR, "lightgbm_comparison_plot.png"))
130
- # plt.show()
131
-
132
- # Subset Plot
133
- subset = results_df.iloc[: len(results_df) // 10]
134
- plt.figure(figsize=(15, 6))
135
- plt.plot(
136
- subset["Timestamp"], subset["True Consumption (MW)"], label="True", color="darkblue"
137
- )
138
- plt.plot(
139
- subset["Timestamp"],
140
- subset["Predicted Consumption (MW)"],
141
- label="Predicted",
142
- color="red",
143
- linestyle="--",
144
- )
145
- plt.title("Predicted vs True (First decile)")
146
- plt.xlabel("Timestamp")
147
- plt.ylabel("Consumption (MW)")
148
- plt.legend()
149
- plt.tight_layout()
150
- plt.savefig(os.path.join(RESULTS_DIR, "lightgbm_prediction_with_timestamp.png"))
151
- # plt.show()
152
-
153
-
154
- # === End message ===
155
- print("\nEvaluation completed.")
156
- print(f"All Plots stored in:\n→ {RESULTS_DIR}")
 
lightgbm_model/scripts/model_loader_wrapper.py DELETED
@@ -1,11 +0,0 @@
1
- from lightgbm_model.scripts.utils import load_lightgbm_model as real_model
2
- from scripts.utils.env import use_dummy
3
-
4
-
5
- def load_lightgbm_model():
6
- if use_dummy():
7
- from scripts.utils.dummy import DummyLightGBMModel
8
-
9
- return DummyLightGBMModel()
10
- else:
11
- return real_model()
 
lightgbm_model/scripts/train/train_lightgbm.py DELETED
@@ -1,66 +0,0 @@
1
- # train_lightgbm.py
2
-
3
- import os
4
- import pickle
5
-
6
- import pandas as pd
7
- from lightgbm import LGBMRegressor, early_stopping, record_evaluation
8
-
9
- from lightgbm_model.scripts.config_lightgbm import (DATA_PATH,
10
- EARLY_STOPPING_ROUNDS,
11
- FEATURES, LIGHTGBM_PARAMS,
12
- MODEL_DIR, RESULTS_DIR,
13
- TARGET)
14
-
15
- # === Load Data ===
16
- df = pd.read_csv(DATA_PATH)
17
-
18
- # Drop date (used later for plots only)
19
- df = df.drop(columns=["date"], errors="ignore")
20
-
21
- # === Time-based Split (70% train, 10% valid, 20% test) ===
22
- train_size = int(len(df) * 0.7)
23
- valid_size = int(len(df) * 0.1)
24
- df_train = df.iloc[:train_size]
25
- df_valid = df.iloc[train_size : train_size + valid_size]
26
- df_test = df.iloc[train_size + valid_size :]
27
-
28
- X_train, y_train = df_train[FEATURES], df_train[TARGET]
29
- X_valid, y_valid = df_valid[FEATURES], df_valid[TARGET]
30
- X_test, y_test = df_test[FEATURES], df_test[TARGET]
31
-
32
-
33
- # === Init LightGBM model ===
34
- eval_result = {}
35
-
36
- model = LGBMRegressor(**LIGHTGBM_PARAMS, verbosity=-1)
37
-
38
- model.fit(
39
- X_train,
40
- y_train,
41
- eval_set=[(X_train, y_train), (X_valid, y_valid)],
42
- eval_metric="rmse",
43
- callbacks=[early_stopping(EARLY_STOPPING_ROUNDS), record_evaluation(eval_result)],
44
- )
45
-
46
- # === Save model ===
47
- os.makedirs(MODEL_DIR, exist_ok=True)
48
- model_path = os.path.join(MODEL_DIR, "lightgbm_final_model.pkl")
49
-
50
- with open(model_path, "wb") as f:
51
- pickle.dump(model, f)
52
-
53
- # === Save evaluation results ===
54
- os.makedirs(RESULTS_DIR, exist_ok=True)
55
- eval_result_path = os.path.join(RESULTS_DIR, "lightgbm_eval_result.pkl")
56
-
57
- with open(eval_result_path, "wb") as f:
58
- pickle.dump(eval_result, f)
59
-
60
- print(f"Model saved to: {model_path}")
61
- print(f"Eval results saved to: {eval_result_path}")
62
-
63
- # === Save data for evaluation ===
64
- X_train.to_csv(os.path.join(RESULTS_DIR, "X_train.csv"), index=False)
65
- X_test.to_csv(os.path.join(RESULTS_DIR, "X_test.csv"), index=False)
66
- y_test.to_csv(os.path.join(RESULTS_DIR, "y_test.csv"), index=False)
 
lightgbm_model/scripts/utils.py DELETED
@@ -1,9 +0,0 @@
1
- import os
2
- import pickle
3
-
4
- MODEL_PATH = os.path.join("lightgbm_model", "model", "lightgbm_final_model.pkl")
5
-
6
-
7
- def load_lightgbm_model():
8
- with open(MODEL_PATH, "rb") as f:
9
- return pickle.load(f)
 
requirements.txt DELETED
@@ -1,38 +0,0 @@
1
- # =============================
2
- # Requirements for Energy Prediction Project
3
- # =============================
4
-
5
- # Python 3.11 environment recommended, since momentfm does not work with later versions
6
-
7
- # Moment Foundation Model (forecasting backbone)
8
- momentfm @ git+https://github.com/moment-timeseries-foundation-model/moment.git@37a8bde4eb3dd340bebc9b54a3b893bcba62cd4f
9
-
10
- # === Core Python stack ===
11
- numpy==1.25.2 # Numerical operations
12
- pandas==2.2.2 # Data manipulation and analysis
13
- matplotlib==3.10.0 # Plotting and visualizations
14
-
15
-
16
- # === Machine Learning ===
17
- scikit-learn==1.6.1 # Evaluation metrics and preprocessing utilities
18
- torch==2.6.0 # PyTorch with CUDA 12.4 (GPU support)
19
- #torchvision==0.21.0 # Optional (can support visual tasks, not critical here)
20
- #torchaudio==2.6.0 # Optional (comes with torch install, can stay)
21
-
22
- # === Utilities ===
23
- tqdm==4.67.1 # Progress bars
24
- ipywidgets>=8.0 # Enables tqdm progress bars in Jupyter/Colab
25
- pprintpp==0.4.0 # Prettier print formatting for nested dicts (used for model output check)
26
-
27
- # === lightgbm ===
28
- lightgbm==4.3.0 # Boosted Trees for tabular modeling (used for baseline and feature selection)
29
-
30
- # === Streamlit App ===
31
- streamlit>=1.30.0
32
- plotly>=5.0.0
33
-
34
- # === for pytest/env dummy/pre-commit/huggingface ====
35
- pytest
36
- python-dotenv
37
- pre-commit
38
- huggingface_hub
 
scripts/utils/dummy.py DELETED
@@ -1,43 +0,0 @@
1
- # scripts/utils/dummy.py
2
- import numpy as np
3
- import torch
4
-
5
-
6
- class DummyDataset:
7
- def __init__(self, length=100):
8
- self.data = np.zeros((length, 10)) # dummy data
9
- self.scaler = DummyScaler()
10
- self.n_channels = 1
11
- self.length = length
12
-
13
- def __len__(self):
14
- return self.length
15
-
16
- def __getitem__(self, idx):
17
- timeseries = np.zeros((48, 1)) # (SEQ_LEN, Channels)
18
- target = np.zeros((1, 1)) # Forecast target
19
- mask = np.ones((48,)) # dummy mask
20
- return timeseries, target, mask
21
-
22
-
23
- class DummyScaler:
24
- def inverse_transform(self, x):
25
- return x # no scaling needed
26
-
27
-
28
- class DummyOutput:
29
- def __init__(self, forecast_shape):
30
- # return a real tensor, as the real model would
31
- self.forecast = torch.tensor(np.full(forecast_shape, 42.0), dtype=torch.float32)
32
-
33
-
34
- class DummyTransformerModel:
35
- def __call__(self, x_enc, input_mask):
36
- batch_size, seq_len, channels = x_enc.shape
37
- forecast_shape = (batch_size, 1, channels)
38
- return DummyOutput(forecast_shape)
39
-
40
-
41
- class DummyLightGBMModel:
42
- def predict(self, X):
43
- return np.zeros(len(X)) # ← now returns an np.ndarray
 
scripts/utils/env.py DELETED
@@ -1,9 +0,0 @@
1
- import os
2
-
3
- from dotenv import load_dotenv
4
-
5
- load_dotenv() # loaded once at import time
6
-
7
-
8
- def use_dummy() -> bool:
9
- return os.getenv("USE_DUMMY_MODEL", "false").lower() == "true"
 
setup.py DELETED
@@ -1,7 +0,0 @@
1
- from setuptools import find_packages, setup
2
-
3
- setup(
4
- name="energy_prediction",
5
- version="0.1",
6
- packages=find_packages(),
7
- )
 
streamlit_simulation/__init__.py DELETED
@@ -1 +0,0 @@
1
- # __init__.py
 
 
streamlit_simulation/app.py DELETED
@@ -1,556 +0,0 @@
1
- import time
2
- import warnings
3
-
4
- import matplotlib.dates as mdates
5
- import matplotlib.pyplot as plt
6
- import numpy as np
7
- import pandas as pd
8
- import streamlit as st
9
- import torch
10
- from config_streamlit import DATA_PATH, PLOT_COLOR, TRAIN_RATIO
11
-
12
- from lightgbm_model.scripts.config_lightgbm import FEATURES
13
- from lightgbm_model.scripts.model_loader_wrapper import load_lightgbm_model
14
- from streamlit_simulation.utils_streamlit import load_data as load_data_raw
15
- from transformer_model.scripts.config_transformer import (FORECAST_HORIZON,
16
- SEQ_LEN)
17
- from transformer_model.scripts.utils.informer_dataset_class import \
18
- InformerDataset
19
- from transformer_model.scripts.utils.model_loader_wrapper import \
20
- load_model_and_dataset
21
-
22
- # ============================== Layout ==============================
23
-
24
- # Streamlit & warnings config
25
- warnings.filterwarnings("ignore", category=FutureWarning)
26
- st.set_page_config(page_title="Electricity Consumption Forecast", layout="wide")
27
-
28
- # CSS part
29
- st.markdown(
30
- f"""
31
- <style>
32
- .stButton > button {{
33
- background-color: {PLOT_COLOR};
34
- }}
35
-
36
- /* Also remove the empty space above the app */
37
- header[data-testid="stHeader"] {{
38
- display: none !important;
39
- height: 0px !important;
40
- visibility: hidden !important;
41
- }}
42
-
43
- .block-container {{
44
- padding-top: 0.5rem !important;
45
- }}
46
-
47
- </style>
48
- """,
49
- unsafe_allow_html=True,
50
- )
51
-
52
-
53
- st.title("Electricity Consumption Forecast: Hourly Simulation")
54
- st.write("Welcome to the simulation interface!")
55
- st.info(
56
- "**Simulation Overview:**\n\n"
57
- "This dashboard provides an hourly electricity consumption forecast using two different models: "
58
- "**LightGBM** and a **Transformer (moment-based)**. Both models generate a fresh prediction at every time step "
59
- "(i.e., every simulated hour).\n\n"
60
- "Note: Since this app runs on a limited CPU on Hugging Face Spaces, the Transformer model may respond slower "
61
- "compared to local execution. On a standard local CPU, performance is significantly better."
62
- )
63
-
64
-
65
- # ============================== Session State Init ===============================
66
- def init_session_state():
67
- defaults = {
68
- "is_running": False,
69
- "start_index": 0,
70
- "true_vals": [],
71
- "pred_vals": [],
72
- "true_timestamps": [],
73
- "pred_timestamps": [],
74
- "last_fig": None,
75
- "valid_pos": 0,
76
- "first_plot_shown": False,
77
- }
78
- for key, value in defaults.items():
79
- if key not in st.session_state:
80
- st.session_state[key] = value
81
-
82
-
83
- init_session_state()
84
-
85
-
86
- # ============================== Loaders Cache ==============================
87
- @st.cache_data
88
- def load_cached_lightgbm_model():
89
- return load_lightgbm_model()
90
-
91
-
92
- @st.cache_resource
93
- def load_transformer_model_and_dataset():
94
- return load_model_and_dataset()
95
-
96
-
97
- @st.cache_data
98
- def load_data():
99
- return load_data_raw()
100
-
101
-
102
- # ============================== Utility Functions ==============================
103
-
104
-
105
- def predict_transformer_step(model, dataset, idx, device):
106
- """Performs a single prediction step with the transformer model."""
107
- timeseries, _, input_mask = dataset[idx]
108
- timeseries = torch.tensor(timeseries, dtype=torch.float32).unsqueeze(0).to(device)
109
- input_mask = torch.tensor(input_mask, dtype=torch.bool).unsqueeze(0).to(device)
110
-
111
- with torch.no_grad():
112
- output = model(x_enc=timeseries, input_mask=input_mask)
113
-
114
- pred = output.forecast[:, 0, :].cpu().numpy().flatten()
115
-
116
- # Rescale back to original units
117
- dummy = np.zeros((len(pred), dataset.n_channels))
118
- dummy[:, 0] = pred
119
- pred_original = dataset.scaler.inverse_transform(dummy)[:, 0]
120
-
121
- return float(pred_original[0])
122
-
123
-
124
- def init_simulation_layout():
125
- """Creates layout containers for plot and info sections."""
126
- col1, spacer, col2 = st.columns([3, 0.2, 1])
127
- plot_title = col1.empty()
128
- plot_container = col1.empty()
129
- x_axis_label = col1.empty()
130
- info_container = col2.empty()
131
- return plot_title, plot_container, x_axis_label, info_container
132
-
133
-
134
- def create_prediction_plot(
135
- pred_timestamps,
136
- pred_vals,
137
- true_timestamps,
138
- true_vals,
139
- window_hours,
140
- y_min=None,
141
- y_max=None,
142
- ):
143
- """Generates the matplotlib figure for plotting prediction vs. actual."""
144
- fig, ax = plt.subplots(
145
- figsize=(8, 5), constrained_layout=True, facecolor=PLOT_COLOR
146
- )
147
- ax.set_facecolor(PLOT_COLOR)
148
-
149
- ax.plot(
150
- pred_timestamps[-window_hours:],
151
- pred_vals[-window_hours:],
152
- label="Prediction",
153
- color="#EF233C",
154
- linestyle="--",
155
- )
156
- if true_vals:
157
- ax.plot(
158
- true_timestamps[-window_hours:],
159
- true_vals[-window_hours:],
160
- label="Actual",
161
- color="#0077B6",
162
- )
163
-
164
- ax.set_ylabel("Consumption (MW)", fontsize=8)
165
- ax.legend(
166
- fontsize=8,
167
- loc="upper left",
168
- bbox_to_anchor=(0, 0.95),
169
- # facecolor= INPUT_BG, # INPUT_BG
170
- # edgecolor= ACCENT_COLOR, # ACCENT_COLOR
171
- # labelcolor= TEXT_COLOR # TEXT_COLOR
172
- )
173
- ax.yaxis.grid(True, linestyle=":", linewidth=0.5, alpha=0.7)
174
- ax.set_ylim(y_min, y_max)
175
- ax.xaxis.set_major_locator(mdates.DayLocator(interval=1))
176
- ax.xaxis.set_major_formatter(mdates.DateFormatter("%m-%d"))
177
- ax.tick_params(axis="x", labelrotation=0, labelsize=5)
178
- ax.tick_params(axis="y", labelsize=5)
179
- # fig.patch.set_facecolor('#e6ecf0') # outer area
180
-
181
- for spine in ax.spines.values():
182
- spine.set_visible(False)
183
-
184
- st.session_state.last_fig = fig
185
- return fig
186
-
187
-
188
- def render_simulation_view(timestamp, prediction, actual, progress, fig, paused=False):
189
- """Displays the simulation plot and metrics in the UI."""
190
- title = "Actual vs. Prediction (Paused)" if paused else "Actual vs. Prediction"
191
- plot_title.markdown(
192
- f"<div style='text-align: center; font-size: 20pt; font-weight: bold; margin-bottom: -0.7rem; margin-top: 0rem;'>"
193
- f"{title}</div>",
194
- unsafe_allow_html=True,
195
- )
196
- plot_container.pyplot(fig)
197
-
198
- # st.markdown("<div style='margin-bottom: 0.5rem;'></div>", unsafe_allow_html=True)
199
- # x_axis_label.markdown(f"<div style='text-align: center; font-size: 13pt; color: {TEXT_COLOR}; margin-top: -0.5rem;'>"f"Time</div>",unsafe_allow_html=True)
200
-
201
- with info_container.container():
202
- st.markdown(
203
- f"<span style='font-size: 24px; font-weight: 600;'>Time: {timestamp}</span>",
204
- unsafe_allow_html=True,
205
- )
206
- st.metric(
207
- "Prediction", f"{prediction:,.0f} MW" if prediction is not None else "–"
208
- )
209
- st.metric("Actual", f"{actual:,.0f} MW" if actual is not None else "–")
210
- st.caption("Simulation Progress")
211
- st.progress(progress)
212
-
213
- if len(st.session_state.true_vals) > 1:
214
- true_arr = np.array(st.session_state.true_vals)
215
- pred_arr = np.array(st.session_state.pred_vals[:-1])
216
- min_len = min(len(true_arr), len(pred_arr))
217
- if min_len >= 1:
218
- errors = np.abs(true_arr[:min_len] - pred_arr[:min_len])
219
- mape = (
220
- np.mean(
221
- errors
222
- / np.where(true_arr[:min_len] == 0, 1e-10, true_arr[:min_len])
223
- )
224
- * 100
225
- )
226
- mae = np.mean(errors)
227
- max_error = np.max(errors)
228
-
229
- st.divider()
230
- st.markdown(
231
- "<span style='font-size: 24px; font-weight: 600; '>Interim Metrics</span>",
232
- unsafe_allow_html=True,
233
- )
234
- st.metric("MAPE (so far)", f"{mape:.2f} %")
235
- st.metric("MAE (so far)", f"{mae:,.0f} MW")
236
- st.metric("Max Error", f"{max_error:,.0f} MW")
237
-
238
-
239
- # ============================== Data Preparation ==============================
240
-
241
- df_full = load_data()
242
-
243
- # Split Train/Test
244
- train_size = int(len(df_full) * TRAIN_RATIO)
245
- test_df_raw = df_full.iloc[train_size:].reset_index(drop=True)
246
-
247
- # Start at first full hour (00:00)
248
- first_full_day_index = test_df_raw[
249
- test_df_raw["date"].dt.time == pd.Timestamp("00:00:00").time()
250
- ].index[0]
251
- test_df_full = test_df_raw.iloc[first_full_day_index:].reset_index(drop=True)
252
-
253
- # Select simulation window via date picker
254
- min_date = test_df_full["date"].min().date()
255
- max_date = test_df_full["date"].max().date()
256
-
257
- # ============================== UI Controls ==============================
258
-
259
- with st.sidebar:
260
- st.header("⚙️ Simulation Settings")
261
-
262
- st.subheader("General Settings")
263
- model_choice = st.selectbox(
264
- "Choose prediction model", ["LightGBM", "Transformer Model (moments)"]
265
- )
266
- if model_choice == "Transformer Model (moments)":
267
- st.caption(
268
- "⚠️ Note: Transformer model runs slower without GPU. (Use Speed = 10)"
269
- )
270
- window_days = st.selectbox("Display window (days)", options=[3, 5, 7], index=0)
271
- window_hours = window_days * 24
272
- speed = st.slider("Speed", 1, 10, 5)
273
-
274
- st.subheader("Date Range")
275
- start_date = st.date_input(
276
- "Start Date", value=min_date, min_value=min_date, max_value=max_date
277
- )
278
- end_date = st.date_input(
279
- "End Date", value=max_date, min_value=min_date, max_value=max_date
280
- )
281
-
282
- # ============================== Data Preparation (filtered) ==============================
283
-
284
- # final filtered date window
285
- test_df_filtered = test_df_full[
286
- (test_df_full["date"].dt.date >= start_date)
287
- & (test_df_full["date"].dt.date <= end_date)
288
- ].reset_index(drop=True)
289
-
290
- # For progression bar
291
- total_steps_ui = len(test_df_filtered)
292
-
293
- # ============================== Buttons ==============================
294
-
295
- st.markdown("### Start Simulation")
296
- col1, col2, col3 = st.columns([1, 1, 4])
297
- with col1:
298
- play_pause_text = "▶️ Start" if not st.session_state.is_running else "⏸️ Pause"
299
- if st.button(play_pause_text, use_container_width=True):
300
- st.session_state.is_running = not st.session_state.is_running
301
- st.rerun()
302
- with col2:
303
- reset_button = st.button("🔄 Reset", use_container_width=True)
304
-
305
- # Reset logic
306
- if reset_button:
307
- st.session_state.start_index = 0
308
- st.session_state.pred_vals = []
309
- st.session_state.true_vals = []
310
- st.session_state.pred_timestamps = []
311
- st.session_state.true_timestamps = []
312
- st.session_state.last_fig = None
313
- st.session_state.is_running = False
314
- st.session_state.valid_pos = 0
315
- st.session_state.first_plot_shown = False
316
- st.rerun()
317
-
318
- # Auto-reset on critical parameter change while running
319
- if st.session_state.is_running and (
320
- start_date != st.session_state.get("last_start_date")
321
- or end_date != st.session_state.get("last_end_date")
322
- or model_choice != st.session_state.get("last_model_choice")
323
- ):
324
- st.session_state.start_index = 0
325
- st.session_state.pred_vals = []
326
- st.session_state.true_vals = []
327
- st.session_state.pred_timestamps = []
328
- st.session_state.true_timestamps = []
329
- st.session_state.last_fig = None
330
- st.session_state.valid_pos = 0
331
- st.session_state.first_plot_shown = False
332
- st.rerun()
333
-
334
- # Track current selections for change detection
335
- st.session_state.last_start_date = start_date
336
- st.session_state.last_end_date = end_date
337
- st.session_state.last_model_choice = model_choice
338
-
339
-
340
- # ============================== Paused Mode ==============================
341
-
342
- if not st.session_state.is_running and st.session_state.last_fig is not None:
343
- st.write("Simulation paused...")
344
- plot_title, plot_container, x_axis_label, info_container = init_simulation_layout()
345
-
346
- timestamp = (
347
- st.session_state.pred_timestamps[-1]
348
- if st.session_state.pred_timestamps
349
- else "–"
350
- )
351
- prediction = st.session_state.pred_vals[-1] if st.session_state.pred_vals else None
352
- actual = st.session_state.true_vals[-1] if st.session_state.true_vals else None
353
- progress = st.session_state.start_index / total_steps_ui
354
-
355
- render_simulation_view(
356
- timestamp, prediction, actual, progress, st.session_state.last_fig, paused=True
357
- )
358
-
359
-
360
- # ============================== initialize values ==============================
361
-
362
- # if LightGBM, use the test data from above
363
- if model_choice == "LightGBM":
364
- test_df = test_df_filtered.copy()
365
-
366
- # Shared state references for storing predictions and ground truths
367
-
368
- true_vals = st.session_state.true_vals
369
- pred_vals = st.session_state.pred_vals
370
- true_timestamps = st.session_state.true_timestamps
371
- pred_timestamps = st.session_state.pred_timestamps
372
-
373
- # ============================== LightGBM Simulation ==============================
374
-
375
- if model_choice == "LightGBM" and st.session_state.is_running:
376
- model = load_cached_lightgbm_model()
377
- st.write("Simulation started...")
378
- st.markdown('<div id="simulation"></div>', unsafe_allow_html=True)
379
-
380
- plot_title, plot_container, x_axis_label, info_container = init_simulation_layout()
381
-
382
- for i in range(st.session_state.start_index, len(test_df)):
383
- if not st.session_state.is_running:
384
- break
385
-
386
- current = test_df.iloc[i]
387
- timestamp = current["date"]
388
- features = current[FEATURES].values.reshape(1, -1)
389
- prediction = model.predict(features)[0]
390
-
391
- pred_vals.append(prediction)
392
- pred_timestamps.append(timestamp)
393
-
394
- if i >= 1:
395
- prev_actual = test_df.iloc[i - 1]["consumption_MW"]
396
- prev_time = test_df.iloc[i - 1]["date"]
397
- true_vals.append(prev_actual)
398
- true_timestamps.append(prev_time)
399
-
400
- fig = create_prediction_plot(
401
- pred_timestamps,
402
- pred_vals,
403
- true_timestamps,
404
- true_vals,
405
- window_hours,
406
- y_min=test_df_filtered["consumption_MW"].min() - 2000,
407
- y_max=test_df_filtered["consumption_MW"].max() + 2000,
408
- )
409
-
410
- render_simulation_view(
411
- timestamp,
412
- prediction,
413
- prev_actual if i >= 1 else None,
414
- i / len(test_df),
415
- fig,
416
- )
417
-
418
- plt.close(fig) # free memory
419
-
420
- st.session_state.start_index = i + 1
421
- time.sleep(1 / (speed + 1e-9))
422
-
423
- st.success("Simulation completed!")
424
-
425
-
426
- # ============================== Transformer Simulation ==============================
427
-
428
- spinner_placeholder = st.empty()
429
-
430
- if model_choice == "Transformer Model (moments)":
431
- if st.session_state.is_running:
432
- st.write("Simulation started (Transformer)...")
433
- st.markdown('<div id="simulation"></div>', unsafe_allow_html=True)
434
-
435
- if not st.session_state.first_plot_shown:
436
- spinner_placeholder.markdown("Running first prediction – please wait...")
437
-
438
- plot_title, plot_container, x_axis_label, info_container = (
439
- init_simulation_layout()
440
- )
441
-
442
- # Access model, dataset, device
443
- model, test_dataset, device = load_transformer_model_and_dataset()
444
- data = test_dataset.data # already scaled
445
- scaler = test_dataset.scaler
446
- n_channels = test_dataset.n_channels
447
-
448
- test_start_idx = (
449
- len(InformerDataset(data_split="train", forecast_horizon=FORECAST_HORIZON))
450
- + SEQ_LEN
451
- )
452
- base_timestamp = pd.read_csv(DATA_PATH, parse_dates=["date"])["date"].iloc[
453
- test_start_idx
454
- ] # get the original timestamp for later, since it is no longer in the dataset
455
-
456
- # Step 1: find the index at which the hour is 00:00
457
- offset = 0
458
- while (base_timestamp + pd.Timedelta(hours=offset)).time() != pd.Timestamp(
459
- "00:00:00"
460
- ).time():
461
- offset += 1
462
-
463
- # New start index for the simulation
464
- start_index = offset
465
-
466
- # Initialize session state if needed
467
- if "start_index" not in st.session_state or st.session_state.start_index == 0:
468
- st.session_state.start_index = start_index
469
-
470
- # Prepare: list of valid indices i within the selected date range
471
- valid_indices = []
472
- for i in range(start_index, len(test_dataset)):
473
- timestamp = base_timestamp + pd.Timedelta(hours=i)
474
- if start_date <= timestamp.date() <= end_date:
475
- valid_indices.append(i)
476
-
477
- # Progress display
478
- total_steps = len(valid_indices)
479
-
480
- # Current position within the list (not the global dataset index!)
481
- if "valid_pos" not in st.session_state:
482
- st.session_state.valid_pos = 0
483
-
484
- # Main loop: iterate over valid indices only
485
- for relative_idx, i in enumerate(valid_indices[st.session_state.valid_pos :]):
486
-
487
- # for i in range(st.session_state.start_index, len(test_dataset)):
488
- if not st.session_state.is_running:
489
- break
490
-
491
- current_pred = predict_transformer_step(model, test_dataset, i, device)
492
- current_time = base_timestamp + pd.Timedelta(hours=i)
493
-
494
- pred_vals.append(current_pred)
495
- pred_timestamps.append(current_time)
496
-
497
- if i >= 1:
498
- prev_actual = test_dataset[i - 1][1][
499
- 0, 0
500
- ] # first forecast value of the previous row
501
- # Rescale back to original units
502
- dummy_actual = np.zeros((1, n_channels))
503
- dummy_actual[:, 0] = prev_actual
504
- actual_val = scaler.inverse_transform(dummy_actual)[0, 0]
505
-
506
- true_time = current_time - pd.Timedelta(hours=1)
507
-
508
- if true_time >= pd.to_datetime(start_date):
509
- true_vals.append(actual_val)
510
- true_timestamps.append(true_time)
511
-
512
- # Create plot
513
- fig = create_prediction_plot(
514
- pred_timestamps,
515
- pred_vals,
516
- true_timestamps,
517
- true_vals,
518
- window_hours,
519
- y_min=test_df_filtered["consumption_MW"].min() - 2000,
520
- y_max=test_df_filtered["consumption_MW"].max() + 2000,
521
- )
522
- if len(pred_vals) >= 2 and len(true_vals) >= 1:
523
- render_simulation_view(
524
- current_time,
525
- current_pred,
526
- actual_val if i >= 1 else None,
527
- st.session_state.valid_pos / total_steps,
528
- fig,
529
- )
530
- if not st.session_state.first_plot_shown:
531
- spinner_placeholder.empty()
532
- st.session_state.first_plot_shown = True
533
-
534
- plt.close(fig) # free memory
535
-
536
- st.session_state.valid_pos += 1
537
- time.sleep(1 / (speed + 1e-9))
538
-
539
- st.success("Simulation completed!")
540
-
541
-
542
- # ============================== Scroll Sync ==============================
543
-
544
- st.markdown(
545
- """
546
- <script>
547
- window.addEventListener("message", (event) => {
548
- if (event.data.type === "save_scroll") {
549
- const pyScroll = event.data.scrollY;
550
- window.parent.postMessage({type: "streamlit:setComponentValue", value: pyScroll}, "*");
551
- }
552
- });
553
- </script>
554
- """,
555
- unsafe_allow_html=True,
556
- )
 
streamlit_simulation/config_streamlit.py DELETED
@@ -1,24 +0,0 @@
1
- # config_streamlit
2
- import os
3
-
4
- # Base directory → points to the project root
5
- BASE_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
6
-
7
- # Model paths
8
- MODEL_PATH_LIGHTGBM = os.path.join(
9
- BASE_DIR, "lightgbm_model", "model", "lightgbm_final_model.pkl"
10
- )
11
- MODEL_PATH_TRANSFORMER = os.path.join(
12
- BASE_DIR, "transformer_model", "model", "checkpoints", "model_final.pth"
13
- )
14
-
15
- # Data path
16
- DATA_PATH = os.path.join(
17
- BASE_DIR, "data", "processed", "energy_consumption_aggregated_cleaned.csv"
18
- )
19
-
20
- # Color palette for Streamlit layout
21
- PLOT_COLOR = "#edf1f7" # Plot background color
22
-
23
- # Constants
24
- TRAIN_RATIO = 0.7 # Train/test split ratio used by both models
 
streamlit_simulation/utils_streamlit.py DELETED
@@ -1,9 +0,0 @@
1
- # utils_streamlit.py
2
- import pandas as pd
3
-
4
- from streamlit_simulation.config_streamlit import DATA_PATH
5
-
6
-
7
- def load_data():
8
- df = pd.read_csv(DATA_PATH, parse_dates=["date"])
9
- return df
 
transformer_model/scripts/__init__.py DELETED
@@ -1 +0,0 @@
1
- # __init__.py
 
 
transformer_model/scripts/config_transformer.py DELETED
@@ -1,33 +0,0 @@
1
- # config_transformer.py
2
- import os
3
-
4
- # Base Directory
5
- BASE_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
6
-
7
- # Data paths
8
- DATA_PATH = os.path.join(
9
- BASE_DIR, "..", "data", "processed", "energy_consumption_aggregated_cleaned.csv"
10
- )
11
-
12
- # Other paths
13
- CHECKPOINT_DIR = os.path.join(BASE_DIR, "model", "checkpoints")
14
- RESULTS_DIR = os.path.join(BASE_DIR, "results")
15
-
16
-
17
- # ========== Model Settings ==========
18
- SEQ_LEN = 512 # Input sequence length (number of time steps the model sees)
19
- FORECAST_HORIZON = 1 # Number of future steps the model should predict
20
- HEAD_DROPOUT = 0.1 # Dropout in the head to prevent overfitting
21
- WEIGHT_DECAY = 0.0 # L2 regularization (0 means off)
22
-
23
- # ========== Training Settings ==========
24
- MAX_EPOCHS = 9 # Optimal number of epochs based on performance curve
25
- BATCH_SIZE = 32 # Batch size for training and evaluation
26
- LEARNING_RATE = 1e-4 # Base learning rate
27
- MAX_LR = 1e-4 # Max LR for OneCycleLR scheduler
28
- GRAD_CLIP = 5.0 # Gradient clipping threshold
29
-
30
- # ========== Freezing Strategy ==========
31
- FREEZE_ENCODER = True
32
- FREEZE_EMBEDDER = True
33
- FREEZE_HEAD = False # just unfreeze the last forecasting head for finetuning
 
transformer_model/scripts/evaluation/__init__.py DELETED
@@ -1 +0,0 @@
1
- # __init__
 
 
transformer_model/scripts/evaluation/evaluate.py DELETED
@@ -1,144 +0,0 @@
1
- # evaluate.py
2
-
3
- import json
4
- import logging
5
- import os
6
-
7
- import numpy as np
8
- import pandas as pd
9
- import torch
10
- from momentfm.utils.utils import control_randomness
11
- from sklearn.metrics import mean_squared_error, r2_score
12
- from tqdm import tqdm
13
-
14
- from transformer_model.scripts.config_transformer import (DATA_PATH,
15
- FORECAST_HORIZON,
16
- RESULTS_DIR, SEQ_LEN)
17
- from transformer_model.scripts.utils.check_device import check_device
18
- from transformer_model.scripts.utils.informer_dataset_class import \
19
- InformerDataset
20
- from transformer_model.scripts.utils.load_final_model import \
21
- load_final_transformer_model
22
-
23
- # Setup logging
24
- logging.basicConfig(
25
- level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
26
- )
27
-
28
-
29
- def evaluate():
30
- control_randomness(seed=13)
31
- # Set device
32
- device, backend, scaler = check_device()
33
- logging.info(f"Evaluation is running on: {backend} ({device})")
34
-
35
- # Load final model
36
- model, _ = load_final_transformer_model(device)
37
-
38
- # Recreate training dataset to get the fitted scaler
39
- train_dataset = InformerDataset(
40
- data_split="train", random_seed=13, forecast_horizon=FORECAST_HORIZON
41
- )
42
-
43
- # Use its scaler in the test dataset
44
- test_dataset = InformerDataset(
45
- data_split="test", random_seed=13, forecast_horizon=FORECAST_HORIZON
46
- )
47
-
48
- test_dataset.scaler = train_dataset.scaler
49
-
50
- test_loader = torch.utils.data.DataLoader(
51
- test_dataset, batch_size=32, shuffle=False
52
- )
53
-
54
- trues, preds = [], []
55
-
56
- with torch.no_grad():
57
- for timeseries, forecast, input_mask in tqdm(
58
- test_loader, desc="Evaluating on test set"
59
- ):
60
- timeseries = timeseries.float().to(device)
61
- forecast = forecast.float().to(device)
62
- input_mask = input_mask.to(device) # <- important!
63
-
64
- output = model(x_enc=timeseries, input_mask=input_mask)
65
-
66
- trues.append(forecast.cpu().numpy())
67
- preds.append(output.forecast.cpu().numpy())
68
-
69
- trues = np.concatenate(trues, axis=0)
70
- preds = np.concatenate(preds, axis=0)
71
-
72
- # Extract only first feature (consumption)
73
- true_values = trues[:, 0, :]
74
- pred_values = preds[:, 0, :]
75
-
76
- # Inverse normalization
77
- n_features = test_dataset.n_channels
78
- true_reshaped = np.column_stack(
79
- [true_values.flatten()]
80
- + [np.zeros_like(true_values.flatten())] * (n_features - 1)
81
- )
82
- pred_reshaped = np.column_stack(
83
- [pred_values.flatten()]
84
- + [np.zeros_like(pred_values.flatten())] * (n_features - 1)
85
- )
86
-
87
- true_original = test_dataset.scaler.inverse_transform(true_reshaped)[:, 0]
88
- pred_original = test_dataset.scaler.inverse_transform(pred_reshaped)[:, 0]
89
-
90
- # Build timestamp index: the date column was cut out in InformerDataset, so we reload the original dataset and use the index of the start of the test data to recover the dates
91
- csv_path = os.path.join(DATA_PATH)
92
- df = pd.read_csv(csv_path, parse_dates=["date"])
93
-
94
- train_len = len(train_dataset)
95
- test_start_idx = train_len + SEQ_LEN
96
- start_timestamp = df["date"].iloc[test_start_idx]
97
- logging.info(f"[DEBUG] timestamp: {start_timestamp}")
98
-
99
- timestamps = [
100
- start_timestamp + pd.Timedelta(hours=i) for i in range(len(true_original))
101
- ]
102
-
103
- df = pd.DataFrame(
104
- {
105
- "Timestamp": timestamps,
106
- "True Consumption (MW)": true_original,
107
- "Predicted Consumption (MW)": pred_original,
108
- }
109
- )
110
-
111
- # Save results to CSV
112
- os.makedirs(RESULTS_DIR, exist_ok=True)
113
- results_path = os.path.join(RESULTS_DIR, "test_results.csv")
114
- df.to_csv(results_path, index=False)
115
- logging.info(f"Saved prediction results to: {results_path}")
116
-
117
- # Evaluation metrics
118
- mse = mean_squared_error(
119
- df["True Consumption (MW)"], df["Predicted Consumption (MW)"]
120
- )
121
- rmse = np.sqrt(mse)
122
- mape = (
123
- np.mean(
124
- np.abs(
125
- (df["True Consumption (MW)"] - df["Predicted Consumption (MW)"])
126
- / df["True Consumption (MW)"]
127
- )
128
- )
129
- * 100
130
- )
131
- r2 = r2_score(df["True Consumption (MW)"], df["Predicted Consumption (MW)"])
132
-
133
- # Save metrics to JSON
134
- metrics = {"RMSE": float(rmse), "MAPE": float(mape), "R2": float(r2)}
135
- metrics_path = os.path.join(RESULTS_DIR, "evaluation_metrics.json")
136
- with open(metrics_path, "w") as f:
137
- json.dump(metrics, f)
138
-
139
- logging.info(f"Saved evaluation metrics to: {metrics_path}")
140
- logging.info(f"RMSE: {rmse:.3f} | MAPE: {mape:.2f}% | R²: {r2:.3f}")
141
-
142
-
143
- if __name__ == "__main__":
144
- evaluate()
 
transformer_model/scripts/evaluation/plot_metrics.py DELETED
@@ -1,106 +0,0 @@
1
- # plot_metrics.py
2
-
3
- import json
4
- import os
5
-
6
- import matplotlib.pyplot as plt
7
- import pandas as pd
8
-
9
- from transformer_model.scripts.config_transformer import RESULTS_DIR
10
-
11
- # === Plot 1: Training Metrics ===
12
-
13
- # Load training metrics
14
- training_metrics_path = os.path.join(RESULTS_DIR, "training_metrics.json")
15
- with open(training_metrics_path, "r") as f:
16
- metrics = json.load(f)
17
-
18
- train_losses = metrics["train_losses"]
19
- test_mses = metrics["test_mses"]
20
- test_maes = metrics["test_maes"]
21
-
22
- plt.figure(figsize=(10, 6))
23
- plt.plot(
24
- range(1, len(train_losses) + 1), train_losses, label="Train Loss", color="blue"
25
- )
26
- plt.plot(range(1, len(test_mses) + 1), test_mses, label="Test MSE", color="red")
27
- plt.plot(range(1, len(test_maes) + 1), test_maes, label="Test MAE", color="green")
28
- plt.xlabel("Epoch")
29
- plt.ylabel("Loss / Metric")
30
- plt.title("Training Loss vs Test Metrics")
31
- plt.legend()
32
- plt.grid(True)
33
-
34
- plot_path = os.path.join(RESULTS_DIR, "training_plot.png")
35
- plt.savefig(plot_path)
36
- print(f"[Saved] Training metrics plot: {plot_path}")
37
- plt.show()
38
-
39
-
40
- # === Plot 2: Predictions vs Ground Truth (Full Range) ===
41
-
42
- # Load comparison results
43
- comparison_path = os.path.join(RESULTS_DIR, "test_results.csv")
44
- df_comparison = pd.read_csv(comparison_path, parse_dates=["Timestamp"])
45
-
46
- plt.figure(figsize=(15, 6))
47
- plt.plot(
48
- df_comparison["Timestamp"],
49
- df_comparison["True Consumption (MW)"],
50
- label="True",
51
- color="darkblue",
52
- )
53
- plt.plot(
54
- df_comparison["Timestamp"],
55
- df_comparison["Predicted Consumption (MW)"],
56
- label="Predicted",
57
- color="red",
58
- linestyle="--",
59
- )
60
- plt.title("Energy Consumption: Predictions vs Ground Truth")
61
- plt.xlabel("Time")
62
- plt.ylabel("Consumption (MW)")
63
- plt.legend()
64
- plt.grid(True)
65
- plt.tight_layout()
66
-
67
- plot_path = os.path.join(RESULTS_DIR, "comparison_plot_full.png")
68
- plt.savefig(plot_path)
69
- print(f"[Saved] Full range comparison plot: {plot_path}")
70
- plt.show()
71
-
72
-
73
- # === Plot 3: Predictions vs Ground Truth (First Month) ===
74
-
75
- first_month_start = df_comparison["Timestamp"].min()
76
- first_month_end = first_month_start + pd.Timedelta(days=25)
77
- df_first_month = df_comparison[
78
- (df_comparison["Timestamp"] >= first_month_start)
79
- & (df_comparison["Timestamp"] <= first_month_end)
80
- ]
81
-
82
- plt.figure(figsize=(15, 6))
83
- plt.plot(
84
- df_first_month["Timestamp"],
85
- df_first_month["True Consumption (MW)"],
86
- label="True",
87
- color="darkblue",
88
- )
89
- plt.plot(
90
- df_first_month["Timestamp"],
91
- df_first_month["Predicted Consumption (MW)"],
92
- label="Predicted",
93
- color="red",
94
- linestyle="--",
95
- )
96
- plt.title("Energy Consumption (First Month): Predictions vs Ground Truth")
97
- plt.xlabel("Time")
98
- plt.ylabel("Consumption (MW)")
99
- plt.legend()
100
- plt.grid(True)
101
- plt.tight_layout()
102
-
103
- plot_path = os.path.join(RESULTS_DIR, "comparison_plot_1month.png")
104
- plt.savefig(plot_path)
105
- print(f"[Saved] 1-Month comparison plot: {plot_path}")
106
- plt.show()
 
transformer_model/scripts/training/__init__.py DELETED
@@ -1 +0,0 @@
1
- # __init__
 
 
transformer_model/scripts/training/load_basis_model.py DELETED
@@ -1,69 +0,0 @@
1
- # load_basis_model.py
2
- # Load and initialize the base MOMENT model before finetuning
3
-
4
- import logging
5
-
6
- import torch
7
- from momentfm import MOMENTPipeline
8
-
9
- from transformer_model.scripts.config_transformer import (FORECAST_HORIZON,
10
- FREEZE_EMBEDDER,
11
- FREEZE_ENCODER,
12
- FREEZE_HEAD,
13
- HEAD_DROPOUT,
14
- SEQ_LEN,
15
- WEIGHT_DECAY)
16
-
17
- # Setup logging
18
- logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")
19
-
20
-
21
- def load_moment_model():
22
- """
23
- Loads and configures the MOMENT model for forecasting.
24
- """
25
- logging.info("Loading MOMENT model...")
26
- model = MOMENTPipeline.from_pretrained(
27
- "AutonLab/MOMENT-1-large",
28
- model_kwargs={
29
- "task_name": "forecasting",
30
- "forecast_horizon": FORECAST_HORIZON, # default = 1
31
- "head_dropout": HEAD_DROPOUT, # default = 0.1
32
- "weight_decay": WEIGHT_DECAY, # default = 0.0
33
- "freeze_encoder": FREEZE_ENCODER, # default = True
34
- "freeze_embedder": FREEZE_EMBEDDER, # default = True
35
- "freeze_head": FREEZE_HEAD, # default = False
36
- },
37
- )
38
-
39
- model.init()
40
- logging.info("Model initialized successfully.")
41
- return model
42
-
43
-
44
- def print_trainable_params(model):
45
- """
46
- Logs all trainable (unfrozen) parameters of the model.
47
- """
48
- logging.info("Unfrozen parameters:")
49
- for name, param in model.named_parameters():
50
- if param.requires_grad:
51
- logging.info(f" {name}")
52
-
53
-
54
- def test_dummy_forward(model):
55
- """
56
- Performs a dummy forward pass to verify the model runs without error.
57
- """
58
- logging.info(
59
- "Running dummy forward pass with random tensors to see if model is running."
60
- )
61
- dummy_x = torch.randn(16, 1, SEQ_LEN)
62
- output = model(x_enc=dummy_x)
63
- logging.info(f"Dummy forward pass successful.Output shape: {output.shape}")
64
-
65
-
66
- if __name__ == "__main__":
67
- model = load_moment_model()
68
- print_trainable_params(model)
69
- test_dummy_forward(model)
 
transformer_model/scripts/training/train.py DELETED
@@ -1,199 +0,0 @@
1
- # train.py
2
-
3
- import json
4
- import logging
5
- import os
6
- import time
7
-
8
- import numpy as np
9
- import torch
10
- from momentfm.utils.utils import control_randomness
11
- from sklearn.metrics import mean_absolute_error, mean_squared_error
12
- from tqdm import tqdm
13
-
14
- from transformer_model.scripts.config_transformer import (CHECKPOINT_DIR,
15
- GRAD_CLIP,
16
- LEARNING_RATE,
17
- MAX_EPOCHS, MAX_LR,
18
- RESULTS_DIR)
19
- from transformer_model.scripts.training.load_basis_model import \
20
- load_moment_model
21
- from transformer_model.scripts.utils.check_device import check_device
22
- from transformer_model.scripts.utils.create_dataloaders import \
23
- create_dataloaders
24
-
25
- # === Setup logging ===
26
- logging.basicConfig(
27
- level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
28
- )
29
-
30
-
31
- def train():
32
- # Start timing
33
- start_time = time.time()
34
-
35
- # Setup device (CUDA / DirectML / CPU) and AMP scaler
36
- device, backend, scaler = check_device()
37
-
38
- # Load base model
39
- model = load_moment_model().to(device)
40
-
41
- # Set random seeds for reproducibility
42
- control_randomness(seed=13)
43
-
44
- # Setup loss function and optimizer
45
- criterion = torch.nn.MSELoss().to(device)
46
- optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
47
-
48
- # Load data
49
- train_loader, test_loader = create_dataloaders()
50
-
51
- # Setup learning rate scheduler (OneCycle policy)
52
- total_steps = len(train_loader) * MAX_EPOCHS
53
- scheduler = torch.optim.lr_scheduler.OneCycleLR(
54
- optimizer, max_lr=MAX_LR, total_steps=total_steps, pct_start=0.3
55
- )
56
-
57
- # Ensure output folders exist
58
- os.makedirs(CHECKPOINT_DIR, exist_ok=True)
59
- os.makedirs(RESULTS_DIR, exist_ok=True)
60
-
61
- # Store metrics
62
- train_losses, test_mses, test_maes = [], [], []
63
-
64
- best_mae = float("inf")
65
- best_epoch = None
66
- no_improve_epochs = 0
67
- patience = 5
68
-
69
- for epoch in range(MAX_EPOCHS):
70
- model.train()
71
- epoch_losses = []
72
-
73
- for timeseries, forecast, input_mask in tqdm(
74
- train_loader, desc=f"Epoch {epoch}"
75
- ):
76
- timeseries = timeseries.float().to(device)
77
- input_mask = input_mask.to(device)
78
- forecast = forecast.float().to(device)
79
-
80
- # Zero gradients
81
- optimizer.zero_grad(set_to_none=True)
82
-
83
- # Forward pass (with AMP if enabled)
84
- if scaler:
85
- with torch.amp.autocast(device_type="cuda"):
86
- output = model(x_enc=timeseries, input_mask=input_mask)
87
- loss = criterion(output.forecast, forecast)
88
- else:
89
- output = model(x_enc=timeseries, input_mask=input_mask)
90
- loss = criterion(output.forecast, forecast)
91
-
92
- # Backward pass + optimization
93
- if scaler:
94
- scaler.scale(loss).backward()
95
- scaler.unscale_(optimizer)
96
- torch.nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP)
97
- scaler.step(optimizer)
98
- scaler.update()
99
- else:
100
- loss.backward()
101
- torch.nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP)
102
- optimizer.step()
103
-
104
- epoch_losses.append(loss.item())
105
-
106
- average_train_loss = np.mean(epoch_losses)
107
- train_losses.append(average_train_loss)
108
- logging.info(f"Epoch {epoch}: Train Loss = {average_train_loss:.4f}")
109
-
110
- # === Evaluation ===
111
- model.eval()
112
- trues, preds = [], []
113
-
114
- with torch.no_grad():
115
- for timeseries, forecast, input_mask in test_loader:
116
- timeseries = timeseries.float().to(device)
117
- input_mask = input_mask.to(device)
118
- forecast = forecast.float().to(device)
119
-
120
- if scaler:
121
- with torch.amp.autocast(device_type="cuda"):
122
- output = model(x_enc=timeseries, input_mask=input_mask)
123
- else:
124
- output = model(x_enc=timeseries, input_mask=input_mask)
125
-
126
- trues.append(forecast.detach().cpu().numpy())
127
- preds.append(output.forecast.detach().cpu().numpy())
128
-
129
- trues = np.concatenate(trues, axis=0)
130
- preds = np.concatenate(preds, axis=0)
131
-
132
- # Reshape for sklearn metrics
133
- trues_2d = trues.reshape(trues.shape[0], -1)
134
- preds_2d = preds.reshape(preds.shape[0], -1)
135
-
136
- mse = mean_squared_error(trues_2d, preds_2d)
137
- mae = mean_absolute_error(trues_2d, preds_2d)
138
-
139
- test_mses.append(mse)
140
- test_maes.append(mae)
141
- logging.info(f"Epoch {epoch}: Test MSE = {mse:.4f}, MAE = {mae:.4f}")
142
-
143
- # === Early Stopping Check ===
144
- if mae < best_mae:
145
- best_mae = mae
146
- best_epoch = epoch
147
- no_improve_epochs = 0
148
-
149
- # Save best model
150
- best_model_path = os.path.join(CHECKPOINT_DIR, "best_model.pth")
151
- torch.save(model.state_dict(), best_model_path)
152
- logging.info(
153
- f"New best model saved to: {best_model_path} (MAE: {best_mae:.4f})"
154
- )
155
- else:
156
- no_improve_epochs += 1
157
- logging.info(f"No improvement in MAE for {no_improve_epochs} epoch(s).")
158
-
159
- if no_improve_epochs >= patience:
160
- logging.info("Early stopping triggered.")
161
- break
162
-
163
- # Save checkpoint
164
- checkpoint_path = os.path.join(CHECKPOINT_DIR, f"model_epoch_{epoch}.pth")
165
- torch.save(model.state_dict(), checkpoint_path)
166
-
167
- scheduler.step()
168
-
169
- logging.info(f"Best model was at epoch {best_epoch} with MAE: {best_mae:.4f}")
170
-
171
- # Save final model
172
- final_model_path = os.path.join(CHECKPOINT_DIR, "model_final.pth")
173
- torch.save(model.state_dict(), final_model_path)
174
- logging.info(f"Final model saved to: {final_model_path}")
175
- logging.info(f"Final Test MSE: {test_mses[-1]:.4f}, MAE: {test_maes[-1]:.4f}")
176
-
177
- # Save training metrics
178
- metrics = {
179
- "train_losses": [float(x) for x in train_losses],
180
- "test_mses": [float(x) for x in test_mses],
181
- "test_maes": [float(x) for x in test_maes],
182
- }
183
-
184
- metrics_path = os.path.join(RESULTS_DIR, "training_metrics.json")
185
- with open(metrics_path, "w") as f:
186
- json.dump(metrics, f)
187
- logging.info(f"Training metrics saved to: {metrics_path}")
188
-
189
- # Done
190
- elapsed = time.time() - start_time
191
- logging.info(f"Training complete in {elapsed / 60:.2f} minutes.")
192
-
193
-
194
- # === Entry Point ===
195
- if __name__ == "__main__":
196
- try:
197
- train()
198
- except Exception as e:
199
- logging.error(f"Training failed: {e}")
 
transformer_model/scripts/utils/__init__.py DELETED
@@ -1 +0,0 @@
1
- # __init__
 
 
transformer_model/scripts/utils/check_device.py DELETED
@@ -1,55 +0,0 @@
1
- import importlib
2
- import subprocess
3
- import sys
4
-
5
- import torch
6
-
7
-
8
- def install_package(package_name):
9
- subprocess.check_call([sys.executable, "-m", "pip", "install", package_name])
10
-
11
-
12
- def check_device():
13
- # **Check for NVIDIA GPU (CUDA)**
14
- if torch.cuda.is_available():
15
- device = torch.device("cuda") # Use NVIDIA GPU
16
- backend = "CUDA (NVIDIA)"
17
- mixed_precision = True # Use Automatic Mixed Precision (AMP)
18
-
19
- # **If no NVIDIA GPU, check for AMD GPU (DirectML) only in Windows**
20
- else:
21
- try:
22
- # Only try DirectML if the environment is Windows and DirectML is installed
23
- if "win32" in sys.platform:
24
- torch_directml = importlib.import_module("torch_directml")
25
- if torch_directml.device_count() > 0:
26
- device = torch_directml.device() # Use AMD GPU with DirectML
27
- backend = "DirectML (AMD)"
28
- mixed_precision = False # No AMP for AMD GPU
29
- else:
30
- raise ImportError # AMD GPU not found
31
- else:
32
- device = torch.device("cpu")
33
- backend = "CPU"
34
- mixed_precision = False # No AMP for CPU
35
-
36
- except ImportError:
37
- # If DirectML is not installed or AMD GPU not found
38
- device = torch.device("cpu")
39
- backend = "CPU"
40
- mixed_precision = False # No AMP for CPU
41
-
42
- # Print the chosen device info
43
- print(f"Training is running on: {backend} ({device})")
44
-
45
- # **Initialize scaler (only for NVIDIA)**
46
- if mixed_precision:
47
- scaler = torch.amp.GradScaler()
48
- else:
49
- scaler = None # No scaler needed for AMD/CPU
50
-
51
- return device, backend, scaler
52
-
53
-
54
- if __name__ == "__main__":
55
- device, backend, scaler = check_device()
 
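A minimal sketch of how the (device, backend, scaler) triple returned by check_device might be consumed in a training step; model, optimizer, criterion and the batch tensors are placeholders for illustration, not part of the deleted code:

    import torch

    from transformer_model.scripts.utils.check_device import check_device

    device, backend, scaler = check_device()

    def training_step(model, optimizer, criterion, x, y):
        # Mixed precision only on the CUDA path, mirroring the helper above.
        optimizer.zero_grad()
        x, y = x.to(device), y.to(device)
        if scaler is not None:
            with torch.autocast(device_type="cuda"):
                loss = criterion(model(x), y)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            loss = criterion(model(x), y)
            loss.backward()
            optimizer.step()
        return loss.item()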
transformer_model/scripts/utils/create_dataloaders.py DELETED
@@ -1,46 +0,0 @@
1
- # create_dataloaders.py
2
-
3
- import logging
4
-
5
- from momentfm.utils.utils import control_randomness
6
- from torch.utils.data import DataLoader
7
-
8
- from transformer_model.scripts.config_transformer import (BATCH_SIZE,
9
- FORECAST_HORIZON)
10
- from transformer_model.scripts.utils.informer_dataset_class import \
11
- InformerDataset
12
-
13
-
14
- def create_dataloaders():
15
- logging.info("Setting random seeds...")
16
- control_randomness(seed=13)
17
-
18
- logging.info("Loading training dataset...")
19
- train_dataset = InformerDataset(
20
- data_split="train", random_seed=13, forecast_horizon=FORECAST_HORIZON
21
- )
22
- logging.info(
23
- "Train set loaded — Samples: %d | Features: %d",
24
- len(train_dataset),
25
- train_dataset.n_channels,
26
- )
27
-
28
- logging.info("Loading test dataset...")
29
- test_dataset = InformerDataset(
30
- data_split="test", random_seed=13, forecast_horizon=FORECAST_HORIZON
31
- )
32
- logging.info(
33
- "Test set loaded — Samples: %d | Features: %d",
34
- len(test_dataset),
35
- test_dataset.n_channels,
36
- )
37
-
38
- train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
39
- test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=True)
40
-
41
- logging.info("Dataloaders created successfully.")
42
- return train_loader, test_loader
43
-
44
-
45
- if __name__ == "__main__":
46
- create_dataloaders()
 
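The dataloaders above can be smoke-tested with a sketch like the following (the expected shapes follow from InformerDataset, shown further down; BATCH_SIZE, SEQ_LEN and FORECAST_HORIZON come from the project's config). Note that the test loader is built with shuffle=True, so evaluation batches are not in chronological order.

    from transformer_model.scripts.utils.create_dataloaders import create_dataloaders

    train_loader, test_loader = create_dataloaders()

    # One batch is enough to confirm the tensor layout produced by InformerDataset.
    timeseries, forecast, input_mask = next(iter(train_loader))
    print(timeseries.shape)  # (BATCH_SIZE, n_channels, SEQ_LEN)
    print(forecast.shape)    # (BATCH_SIZE, n_channels, FORECAST_HORIZON)
    print(input_mask.shape)  # (BATCH_SIZE, SEQ_LEN)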
transformer_model/scripts/utils/informer_dataset_class.py DELETED
@@ -1,123 +0,0 @@
1
- # informer_dataset.py
2
-
3
- import logging
4
- from typing import Optional
5
-
6
- import numpy as np
7
- import pandas as pd
8
- from sklearn.preprocessing import StandardScaler
9
-
10
- from transformer_model.scripts.config_transformer import DATA_PATH, SEQ_LEN
11
-
12
- logging.basicConfig(level=logging.INFO)
13
-
14
-
15
- class InformerDataset:
16
- def __init__(
17
- self,
18
- forecast_horizon: Optional[int],
19
- data_split: str = "train",
20
- data_stride_len: int = 1,
21
- task_name: str = "forecasting",
22
- random_seed: int = 42,
23
- ):
24
- """
25
- Parameters
26
- ----------
27
- forecast_horizon : int, optional
28
- Length of the prediction sequence.
29
- data_split : str
30
- 'train' or 'test'.
31
- data_stride_len : int
32
- Stride length between time windows.
33
- task_name : str
34
- 'forecasting' or 'imputation'.
35
- random_seed : int
36
- For reproducibility.
37
- """
38
-
39
- self.seq_len = SEQ_LEN
40
- self.forecast_horizon = forecast_horizon
41
- self.full_file_path_and_name = DATA_PATH
42
- self.data_split = data_split
43
- self.data_stride_len = data_stride_len
44
- self.task_name = task_name
45
- self.random_seed = random_seed
46
-
47
- self._read_data()
48
-
49
- def _get_borders(self):
50
- train_ratio = 0.7
51
- n_train = int(self.length_timeseries_original * train_ratio)
52
- n_test = self.length_timeseries_original - n_train
53
-
54
- train_end = n_train
55
- test_start = train_end - self.seq_len
56
- test_end = test_start + n_test + self.seq_len
57
-
58
- # logging.info(f"Train range: 0 to {train_end}")
59
- # logging.info(f"Test range: {test_start} to {test_end}")
60
-
61
- return slice(0, train_end), slice(test_start, test_end)
62
-
63
- def _read_data(self):
64
- self.scaler = StandardScaler()
65
-
66
- df = pd.read_csv(self.full_file_path_and_name)
67
- self.length_timeseries_original = df.shape[0]
68
- self.n_channels = df.shape[1] - 1 # exclude timestamp column
69
-
70
- df.drop(columns=["date"], inplace=True)
71
- df = df.infer_objects(copy=False).interpolate(method="cubic")
72
-
73
- data_splits = self._get_borders()
74
- train_data = df[data_splits[0]]
75
-
76
- self.scaler.fit(train_data.values)
77
- df = self.scaler.transform(df.values)
78
-
79
- if self.data_split == "train":
80
- self.data = df[data_splits[0], :]
81
- elif self.data_split == "test":
82
- self.data = df[data_splits[1], :]
83
-
84
- self.length_timeseries = self.data.shape[0]
85
-
86
- # logging.info(f"{self.data_split.capitalize()} set loaded.")
87
- # logging.info(f"Time series length: {self.length_timeseries}")
88
- # logging.info(f"Number of features: {self.n_channels}")
89
-
90
- def __getitem__(self, index):
91
- seq_start = self.data_stride_len * index
92
- seq_end = seq_start + self.seq_len
93
- input_mask = np.ones(self.seq_len)
94
-
95
- if self.task_name == "forecasting":
96
- pred_end = seq_end + self.forecast_horizon
97
-
98
- if pred_end > self.length_timeseries:
99
- pred_end = self.length_timeseries
100
- seq_end = seq_end - self.forecast_horizon
101
- seq_start = seq_end - self.seq_len
102
-
103
- timeseries = self.data[seq_start:seq_end, :].T
104
- forecast = self.data[seq_end:pred_end, :].T
105
-
106
- return timeseries, forecast, input_mask
107
-
108
- elif self.task_name == "imputation":
109
- if seq_end > self.length_timeseries:
110
- seq_end = self.length_timeseries
111
- seq_end = seq_end - self.seq_len
112
-
113
- timeseries = self.data[seq_start:seq_end, :].T
114
-
115
- return timeseries, input_mask
116
-
117
- def __len__(self):
118
- if self.task_name == "imputation":
119
- return (self.length_timeseries - self.seq_len) // self.data_stride_len + 1
120
- elif self.task_name == "forecasting":
121
- return (
122
- self.length_timeseries - self.seq_len - self.forecast_horizon
123
- ) // self.data_stride_len + 1
 
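The window count returned by __len__ for the forecasting task follows directly from the slicing in __getitem__; a worked example with assumed values (not the project's actual SEQ_LEN and FORECAST_HORIZON):

    # All numbers here are illustrative assumptions, not values from the config.
    length_timeseries = 1000  # rows in the chosen split
    seq_len = 512             # stands in for SEQ_LEN
    forecast_horizon = 96     # stands in for FORECAST_HORIZON
    stride = 1                # data_stride_len

    n_windows = (length_timeseries - seq_len - forecast_horizon) // stride + 1
    print(n_windows)  # 393 -> for the last index, seq_end + forecast_horizon == length_timeseries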
transformer_model/scripts/utils/load_final_model.py DELETED
@@ -1,39 +0,0 @@
1
- import logging
2
- import os
3
-
4
- import torch
5
- from huggingface_hub import hf_hub_download
6
-
7
- from transformer_model.scripts.config_transformer import CHECKPOINT_DIR
8
- from transformer_model.scripts.training.load_basis_model import \
9
- load_moment_model
10
-
11
- logging.basicConfig(level=logging.INFO)
12
-
13
-
14
- # Load the model from a local checkpoint if available; otherwise download it from the Hugging Face Hub
15
- def load_real_transformer_model(device=None): # note: name changed
16
- if device is None:
17
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
18
-
19
- model = load_moment_model()
20
- filename = "model_final.pth"
21
- local_path = os.path.join(CHECKPOINT_DIR, filename)
22
-
23
- if os.path.exists(local_path):
24
- checkpoint_path = local_path
25
- print("Loading model from local path...")
26
- else:
27
- print("Downloading model from Hugging Face Hub...")
28
- checkpoint_path = hf_hub_download(
29
- repo_id="dlaj/energy-forecasting-files", # adjust if needed
30
- filename=f"transformer_model/{filename}",
31
- repo_type="dataset",
32
- )
33
-
34
- model.load_state_dict(torch.load(checkpoint_path, map_location=device))
35
- model.to(device)
36
- model.eval()
37
- logging.info(f"Model loaded from: {checkpoint_path}")
38
-
39
- return model, device
 
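A hedged usage sketch for the loader above; it only inspects the returned model, since the exact forward signature of the MOMENT pipeline is not part of this diff:

    import torch

    from transformer_model.scripts.utils.load_final_model import load_real_transformer_model

    model, device = load_real_transformer_model()
    n_params = sum(p.numel() for p in model.parameters())
    print(f"Model loaded on {device} with {n_params:,} parameters")

    with torch.no_grad():
        # Inference on a prepared window would go here; the call signature
        # depends on the MOMENT pipeline and is assumed, not shown in this diff.
        pass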
transformer_model/scripts/utils/model_loader_wrapper.py DELETED
@@ -1,41 +0,0 @@
1
- from scripts.utils.env import use_dummy
2
- from transformer_model.scripts.config_transformer import FORECAST_HORIZON
3
- from transformer_model.scripts.utils.informer_dataset_class import \
4
- InformerDataset
5
- from transformer_model.scripts.utils.load_final_model import \
6
- load_real_transformer_model
7
-
8
- try:
9
- from scripts.utils.dummy import DummyDataset, DummyTransformerModel
10
- except ImportError:
11
- DummyTransformerModel = None
12
- DummyDataset = None
13
-
14
-
15
- def load_final_transformer_model():
16
- if use_dummy():
17
- if DummyTransformerModel is None:
18
- raise ImportError("DummyTransformerModel not available")
19
- return DummyTransformerModel(), "cpu"
20
- else:
21
- return load_real_transformer_model()
22
-
23
-
24
- def load_model_and_dataset():
25
- model, device = load_final_transformer_model()
26
-
27
- if use_dummy():
28
- if DummyDataset is None:
29
- raise ImportError("DummyDataset not available")
30
- dataset = DummyDataset(length=200)
31
- else:
32
- train_dataset = InformerDataset(
33
- data_split="train", random_seed=13, forecast_horizon=FORECAST_HORIZON
34
- )
35
- test_dataset = InformerDataset(
36
- data_split="test", random_seed=13, forecast_horizon=FORECAST_HORIZON
37
- )
38
- test_dataset.scaler = train_dataset.scaler
39
- dataset = test_dataset
40
-
41
- return model, dataset, device
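
Finally, a minimal sketch of the wrapper in use, assuming the dummy objects mirror the (timeseries, forecast, input_mask) interface of InformerDataset:

    from transformer_model.scripts.utils.model_loader_wrapper import load_model_and_dataset

    model, dataset, device = load_model_and_dataset()

    # With use_dummy() returning False, `dataset` is the InformerDataset test split
    # whose scaler was fitted on the training split, as set up above.
    timeseries, forecast, input_mask = dataset[0]
    print(type(model).__name__, len(dataset), timeseries.shape)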