
Forecasting

Sacha Lewin edited this page Sep 24, 2025 · 5 revisions

This page explains how to generate forecasts using a trained denoiser.

The script is located at experiments/diffusion/forecast.py.

For forecasting, we do not compose blankets, as we use an autoregressive approach. We still provide an all-at-once forecasting script with blankets in forecast_aao.py, but we did not include it in the paper due to its inferior performance: the few known states to condition on, compared to the size of the trajectory, significantly limit information flow during denoising.

Autoregressive forecasting

There are three types of forecasts we consider. They differ only in how the first autoregressive step is performed.

  • From full states: The first step conditions on $n$ full latent states.
  • From reanalysis: The first step performs reanalysis on a full blanket, then the last $n$ states are used for the autoregressive rollout.
  • From observations: The first step simultaneously assimilates the observations on $n$ states and predicts the next ones. Subsequent steps are autoregressive, conditioning on previously generated latent states.

Each forecast is performed by splitting the generated window into two parts over time. The left part contains the conditioning states, while the right part, after generation, contains the forecasted states. The conditioning states are centered, so a window of 6 states conditioned on 2 states has those states located at indices 1 and 2 (2nd and 3rd states).
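As an illustration, the split and the rollout can be sketched in a few lines of Python. This is a minimal sketch, not the code in forecast.py: the exact centring rule and the `denoise` callable (which stands in for one conditional diffusion generation) are assumptions.

```python
def split_window(window_size: int, num_cond: int) -> tuple[list[int], list[int]]:
    """Split a generation window into conditioning and forecast indices.

    Assumption: the conditioning states are centred (biased right) in the
    left half of the window, matching the example above (window of 6,
    2 conditioning states at indices 1 and 2).
    """
    half = window_size // 2
    start = (half - num_cond + 1) // 2
    return list(range(start, start + num_cond)), list(range(half, window_size))


def autoregressive_forecast(initial, lead_time, preds_per_step, past_window, denoise):
    """Roll out lead_time states, preds_per_step at a time.

    `denoise` is a placeholder for one conditional diffusion generation: it
    receives the sliding window of past states and returns new states. The
    total number of generations is lead_time / preds_per_step.
    """
    trajectory = list(initial)
    for _ in range(lead_time // preds_per_step):
        past = trajectory[-past_window:]  # sliding window of GT/generated states
        trajectory.extend(denoise(past))
    return trajectory[len(initial):]  # forecasted states only


cond, fcst = split_window(6, 2)  # → ([1, 2], [3, 4, 5])

# Toy denoiser: each generated "state" just increments the previous one.
toy = lambda past: [past[-1] + i + 1 for i in range(3)]
out = autoregressive_forecast([0, 1, 2], lead_time=6, preds_per_step=3,
                              past_window=3, denoise=toy)  # → [3, 4, 5, 6, 7, 8]
```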

Configuration file

```yaml
model_path: /path/to/project/autoencoders/your_ae/1/latents/wiki/denoisers/your_denoiser/0
model_target: last  # Options: best, last

diffusion:
  num_steps: 32  # Defaults to model's validation denoising steps
  mmps_iters: 2
  sampler:
    type: pc
    config:
      corrections: 2
      delta: 0.1

assimilation_lengths: "3"
lead_times: "24"

# If one of them is auto, will be equal to blanket_size minus the other
preds_per_step: 3  # States to predict per AR step.
past_window_size: auto  # Size of the sliding window of past GT/generated states

# Samples for a given window size
num_samples_per_date: 5
start_dates:
  - "2000-03-10 0h"
  - "2000-03-20 0h"
  - "2000-04-10 0h"
  - "2000-04-20 0h"

# Initialization mode.
#   full: use full encoded latent states
#   observations: perform partial reanalysis from observations: [X X X X obs obs | forecast forecast ...]
#   reanalysis: load the last n states from reanalysis performed before. TODO: do reanalysis from this script automatically & shift.
initialization: full

# If observing masked pixel states first.
observed_variables:
  stations:
    enabled: true
    low: 0
    high: 1  # included
  satellites:
    enabled: true
    low: 2
    high: 3  # included
masks:
  - name: leo
    type: satellite
    covariance: 1e-2
    config:
      orbital_altitude: 800
      inclination: 75
      initial_phase: 0
      obs_freq: 60
      fov: 5
  - name: weather11k
    type: stations
    covariance: 1e-4
    config:
      num_stations: 11k
      num_valid_stations: 0

precision: float16  # Options: float32, float16, bfloat16, null = model training precision

hardware:
  backend: slurm  # or async
  account: your_account
  gen:
    cpus: 8
    ram: 60GB
    time: "10:00"
    partition: your_partition
  aggregate:
    cpus: 4
    ram: 60GB
    time: "5:00"
    partition: your_partition
```

  • model_path: Absolute path to your model, including the lap, i.e., it ends in .../model_id/lap.
  • model_target: Whether to use the best model checkpoint, according to validation loss, or the latest one saved.
  • diffusion: Diffusion settings. If num_steps is null, the number of denoising steps defaults to the validation steps from the training configuration. Available samplers are pc, ddpm, ddim, rewind, and lms. More information can be found on the Samplers page.
  • assimilation_lengths: The number of initial ground-truth states/observations to assimilate.
  • lead_times: The number of states to forecast.
  • preds_per_step: The number of states to keep after each autoregressive step. The number of diffusion generations needed to forecast the full trajectory is lead_time / preds_per_step.
  • past_window_size: As explained above, the window is split in two over time between conditioning and forecast. This setting defines the number of past GT/generated states used for conditioning. It should be greater than or equal to assimilation_length.
  • num_samples_per_date / start_dates: Each inference script follows the literature and generates ensembles of trajectories. Make sure these dates are included in your total ERA5 data, as the autoencoder takes in context and timestamps that can be loaded from there. Ground truth will also be recovered for evaluation and rendering. Throughout conditional inference, we choose 4 ensembles of size 5 to obtain a meaningful probabilistic evaluation. In the paper, we used 4 ensembles of size 10.
  • initialization: This defines which initialization mode to use, as explained above.
  • observed_variables: This setting only concerns the reanalysis and observations initializations. It defines which variables are observed through the satellite and station masks. Note that high indices are inclusive!
  • masks: This defines the synthetic masks used to turn ground-truth data into observations. These settings are the defaults used in the paper.
  • precision: Specifies the inference precision: float16, bfloat16, or float32. We did not notice major differences at inference time; more importantly, we recommend training in float32.
  • hardware: Hardware settings follow the other scripts. Do not forget to put quotes around times in minutes, or PyYAML will parse them as base-60 integers rather than time strings.
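The quoting gotcha is easy to reproduce: PyYAML follows YAML 1.1, where an unquoted colon-separated number is parsed as a base-60 (sexagesimal) integer.

```python
import yaml  # PyYAML

# YAML 1.1 sexagesimal integers: unquoted 10:00 parses as 10 * 60 + 0 = 600.
unquoted = yaml.safe_load("time: 10:00")   # {'time': 600}
quoted = yaml.safe_load('time: "10:00"')   # {'time': '10:00'}
```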

Once configured, you can run python forecast.py. You can, again, set an ID with +id=my_id. For the rest of the wiki, we will assume that you generated forecasts for all three initialization modes, using the IDs wiki_full, wiki_rea, and wiki_obs.

Next step

We recommend rendering your states and evaluating them after inference, to check that your results make sense.

Next, you can move on to reanalysis.
