# File: wxformer_6h/finetune_final/model_single_cached.yml
# Commit 8c9de0c (djgagne): Added single step and finetune weights
# --------------------------------------------------------------------------------------------------------------------- #
# This YAML file configures a 6-hourly state-in-state-out crossformer
# on NSF NCAR HPCs (casper.ucar.edu and derecho.hpc.ucar.edu).
# The model is trained on 6-hourly model-level ERA5 data, with top solar irradiance, geopotential, and land-sea mask as inputs.
# Output variables: model level [U, V, T, Q], single level [SP, t2m], and 500 hPa [U, V, T, Z, Q]
# --------------------------------------------------------------------------------------------------------------------- #
# run-level save location (a separate key from the dataset paths in the `data` section)
# NOTE(review): presumably where checkpoints and logs are written — confirm against the consumer
save_loc: '/glade/work/ksha/CREDIT_runs/wxformer_6h/'
# seed value (presumably for random number generators — confirm)
seed: 1000
# Dataset paths, variable lists, normalization statistics, and sampling
# settings. All keys below are nested under `data`, so `data.save_loc` does
# not collide with the top-level `save_loc`.
data:
  # upper-air variables
  variables: ['U', 'V', 'T', 'Q']
  save_loc: '/glade/derecho/scratch/ksha/CREDIT_data/arXiv_cached/cache_arXiv_6h_*'

  # surface variables
  surface_variables: ['SP', 't2m', 'V500', 'U500', 'T500', 'Z500', 'Q500']
  save_loc_surface: '/glade/derecho/scratch/ksha/CREDIT_data/arXiv_cached/cache_arXiv_6h_*'

  # dynamic forcing variables
  dynamic_forcing_variables: ['tsi']
  save_loc_dynamic_forcing: '/glade/derecho/scratch/ksha/CREDIT_data/arXiv_cached/cache_arXiv_6h_*'

  # static variables
  static_variables: ['Z_GDS4_SFC', 'LSM']
  save_loc_static: '/glade/derecho/scratch/ksha/CREDIT_data/static_norm_old.nc'

  # mean / std path
  mean_path: '/glade/derecho/scratch/ksha/CREDIT_data/mean_6h_1979_2018_16lev_0.25deg.nc'
  std_path: '/glade/derecho/scratch/ksha/CREDIT_data/std_residual_6h_1979_2018_16lev_0.25deg.nc'

  # train / validation split
  # NOTE(review): 2018 appears in both ranges; presumably the end year is
  # exclusive — confirm against the data loader
  train_years: [1979, 2018]
  valid_years: [2018, 2019]

  # data workflow
  scaler_type: 'std_cached'

  # state-in-state-out
  history_len: 1
  valid_history_len: 1
  forecast_len: 0
  valid_forecast_len: 0
  one_shot: true

  # 1 for hourly model; 6 here for the 6-hourly model
  lead_time_periods: 6

  # do not use skip_period
  skip_periods: null

  # compatible with the old 'std'
  static_first: true
# Training-loop settings. All keys below are nested under `trainer`, so
# `trainer.type` does not collide with the `type` keys of other sections.
trainer:
  type: standard  # <---------- change to your type
  mode: fsdp
  cpu_offload: false
  activation_checkpoint: true

  load_weights: true
  load_optimizer: true
  load_scaler: true
  # NOTE(review): the key is spelled `load_sheduler` (not `load_scheduler`).
  # It is kept byte-for-byte because the consuming code is not visible here;
  # confirm which spelling the trainer actually reads.
  load_sheduler: true

  skip_validation: false
  update_learning_rate: false

  save_backup_weights: true
  save_best_weights: true

  learning_rate: 1.0e-03  # <-- change to your lr
  weight_decay: 0

  train_batch_size: 1
  valid_batch_size: 1

  batches_per_epoch: 0
  valid_batches_per_epoch: 0
  stopping_patience: 999

  start_epoch: 0
  num_epoch: 6
  reload_epoch: true
  epochs: &epochs 70

  use_scheduler: true
  # T_max aliases the `epochs` anchor above, so the two values stay in sync
  scheduler:
    scheduler_type: 'cosine-annealing'
    T_max: *epochs
    last_epoch: -1

  # automatic mixed precision (disabled)
  amp: false

  # rescale loss as loss = loss / grad_accum_every
  grad_accum_every: 1

  # gradient clipping
  grad_max_norm: 1.0

  # number of workers
  thread_workers: 4
  valid_thread_workers: 0
# Model architecture settings. All keys below are nested under `model`.
model:
  # crossformer example
  type: 'crossformer'
  frames: 1                # number of input states (default: 1)
  image_height: 640        # number of latitude grids (default: 640)
  image_width: 1280        # number of longitude grids (default: 1280)
  levels: 16               # number of upper-air variable levels (default: 15)
  channels: 4              # upper-air variable channels
  surface_channels: 7      # surface variable channels
  input_only_channels: 3   # dynamic forcing, forcing, static channels
  output_only_channels: 0  # diagnostic variable channels

  # NOTE(review): the original comments described patch_width as latitude
  # grids and patch_height as longitude grids, the reverse of the
  # image_height / image_width comments above. Both values are 1 here, so
  # nothing changes numerically — confirm the intended axis mapping.
  patch_width: 1           # grids per 3D patch along the width axis (default: 1)
  patch_height: 1          # grids per 3D patch along the height axis (default: 1)
  frame_patch_size: 1      # number of input states in each 3D patch (default: 1)

  dim: [128, 256, 512, 1024]         # dimensionality of each layer
  depth: [2, 2, 8, 2]                # depth of each layer
  global_window_size: [10, 5, 2, 1]  # global window size for each layer
  local_window_size: 10              # local window size

  # kernel sizes for cross-embedding, one list per layer
  cross_embed_kernel_sizes:
    - [4, 8, 16, 32]
    - [2, 4]
    - [2, 4]
    - [2, 4]

  cross_embed_strides: [2, 2, 2, 2]  # strides for cross-embedding (default: [4, 2, 2, 2])
  attn_dropout: 0.0                  # dropout probability for attention layers (default: 0.0)
  ff_dropout: 0.0                    # dropout probability for feed-forward layers (default: 0.0)

  # map boundary padding
  pad_lon: 80  # number of grids to pad on 0 and 360 deg lon
  pad_lat: 80  # number of grids to pad on -90 and 90 deg lat
# Loss-function settings. All keys below are nested under `loss`.
loss:
  # primary training loss
  training_loss: 'mse'

  # power loss and spectral loss are both switched off
  use_power_loss: false
  use_spectral_loss: false

  # latitude weighting is switched on, with the file given for the weights
  use_latitude_weights: true
  latitude_weights: '/glade/u/home/wchapman/MLWPS/DataLoader/LSM_static_variables_ERA5_zhght.nc'

  # per-variable weighting is switched off
  use_variable_weights: false
# Rollout / inference settings.
# NOTE(review): nesting was reconstructed from flattened text. The keys from
# `type` through `days` are placed under `forecasts`; `save_forecast`,
# `save_vars`, and `use_laplace_filter` are placed directly under `predict`.
# Confirm against the consumer.
predict:
  forecasts:
    type: 'custom'        # keep it as "custom"
    start_year: 2020      # year of the first initialization (where rollout will start)
    start_month: 1        # month of the first initialization
    start_day: 1          # day of the first initialization
    start_hours: [0, 12]  # hour-of-day for each initialization, 0 for 00Z, 12 for 12Z
    # number of days to initialize, starting from the (year, month, day) above;
    # duration should be divisible by the number of GPUs
    # (e.g., duration: 384 for 365-day rollout using 32 GPUs)
    duration: 30
    days: 2               # forecast lead time as days (1 means 24-hour forecast)

  save_forecast: '/glade/derecho/scratch/ksha/CREDIT/wxformer_6h/'
  save_vars: ['SP', 't2m', 'V500', 'U500', 'T500', 'Z500', 'Q500']

  # turn off low-pass filter
  use_laplace_filter: false

  # deprecated
  # save_format: "nc"
# PBS job settings (derecho). All keys below are nested under `pbs`.
pbs:
  conda: '/glade/work/ksha/miniconda3/envs/credit'
  project: 'NAML0001'
  job_name: 'wxformer_6h'
  # kept quoted so the colon-separated time stays a string
  walltime: '12:00:00'
  nodes: 8
  ncpus: 64
  ngpus: 4
  mem: '480GB'
  queue: 'main'