Commit 12861881 authored by Aaron Spring

Update scripts.py, clip(-10, 1) and penalize

parent 9228d336
Merge request !9: Update scripts.py, clip(-10, 1) and penalize
@@ -2,6 +2,15 @@
### unreleased
- Order of processing gridded `RPSS` to final score: (#7, !9, [s2s-ai-competition-scoring-image!2](https://renkulab.io/gitlab/tasko.olevski/s2s-ai-competition-scoring-image/-/merge_requests/2), [Aaron Spring](https://renkulab.io/gitlab/aaron.spring))
1. `RPSS`
2. penalize #7
3. `clip(-10,1)`
4. mean over `forecast_time`
5. spatially weighted mean [90N-60S]
6. mean over `lead_time` and `data_vars`
- Don't forget to also `git add current_notebook.ipynb`, to ensure that a consistent training pipeline and submission file are tagged and added under notebooks. (!9, [Aaron Spring](https://renkulab.io/gitlab/aaron.spring))
- Rerun [`ML_train_and_predict.ipynb`](https://renkulab.io/gitlab/aaron.spring/s2s-ai-challenge-template/-/blob/master/notebooks/ML_train_and_predict.ipynb) (!9, [Aaron Spring](https://renkulab.io/gitlab/aaron.spring))
- Fix typo in safeguards in [ML_forecast_template.ipynb](https://renkulab.io/gitlab/aaron.spring/s2s-ai-challenge-template/-/blob/master/notebooks/ML_forecast_template.ipynb): "We did NOT use `test` explicitly in training or implicitly in incrementally adjusting parameters." (!8, [Aaron Spring](https://renkulab.io/gitlab/aaron.spring))
- Add notebooks showcasing accessing output of different models from different sources: (!2, [Aaron Spring](https://renkulab.io/gitlab/aaron.spring))
- S2S-Project models:
%% Cell type:markdown id: tags:
# Train ML model for predictions of week 3-4 & 5-6
This notebook creates a Machine Learning `ML_model` to predict weeks 3-4 & 5-6 based on `S2S` weeks 3-4 & 5-6 forecasts; predictions are compared to `CPC` observations for the [`s2s-ai-challenge`](https://s2s-ai-challenge.github.io/).
%% Cell type:markdown id: tags:
# Synopsis
%% Cell type:markdown id: tags:
## Method: `name`
- description
- a few details
%% Cell type:markdown id: tags:
## Data used
Training-input for Machine Learning model:
- renku datasets, climetlab, IRIDL
Forecast-input for Machine Learning model:
- renku datasets, climetlab, IRIDL
Compare Machine Learning model forecast against ground truth:
- renku datasets, climetlab, IRIDL
%% Cell type:markdown id: tags:
## Resources used
for training; details in the Reproducibility section
- platform: renku
- memory: 8 GB
- processors: 2 CPU
- storage required: 10 GB
%% Cell type:markdown id: tags:
## Safeguards
All points have to be checked [x]. If not, your submission is invalid.
Changes to the code after submission are not possible, as the `commit` preceding the `tag` will be reviewed.
(Only in exceptional cases, and where prior effort toward reproducibility is evident, may improvements to readability and reproducibility be allowed after November 1st 2021.)
%% Cell type:markdown id: tags:
### Safeguards to prevent [overfitting](https://en.wikipedia.org/wiki/Overfitting?wprov=sfti1)
If the organizers suspect overfitting, your contribution can be disqualified.
- [ ] We did not use 2020 observations in training (explicit overfitting and cheating)
- [ ] We did not repeatedly verify our model on 2020 observations nor incrementally improve our RPSS (implicit overfitting)
- [ ] We provide RPSS scores for the training period with script `skill_by_year`, see in section 6.3 `predict`.
- [ ] We tried our best to prevent [data leakage](https://en.wikipedia.org/wiki/Leakage_(machine_learning)?wprov=sfti1).
- [ ] We honor the `train-validate-test` [split principle](https://en.wikipedia.org/wiki/Training,_validation,_and_test_sets). This means that the hindcast data is split into `train` and `validate`, whereas `test` is withheld.
- [ ] We did not use `test` explicitly in training or implicitly in incrementally adjusting parameters.
- [ ] We considered [cross-validation](https://en.wikipedia.org/wiki/Cross-validation_(statistics)).
%% Cell type:markdown id: tags:
### Safeguards for Reproducibility
Notebook/code must be independently reproducible from scratch by the organizers (after the competition); if this is not possible, no prize will be awarded.
- [ ] All training data is publicly available (no pre-trained private neural networks, as they are not reproducible for us)
- [ ] Code is well documented, readable and reproducible.
- [ ] Code to reproduce training and predictions should preferably run within a day on the described architecture. If training takes longer than a day, please justify why this is needed. Please do not submit training pipelines that take weeks to train.
%% Cell type:markdown id: tags:
# Todos to improve template
This is just a demo.
- [ ] for both variables
- [ ] for both `lead_time`s
- [ ] ensure probabilistic prediction outcome with `category` dim
%% Cell type:markdown id: tags:
# Imports
%% Cell type:code id: tags:
``` python
from tensorflow import keras
from tensorflow.keras.layers import Input, Dense, Flatten
from tensorflow.keras.models import Sequential
import numpy as np
import matplotlib.pyplot as plt
import xarray as xr
xr.set_options(display_style='text')
from dask.utils import format_bytes
import xskillscore as xs
```
%% Cell type:markdown id: tags:
# Get training data
preprocessing of input data may be done in a separate notebook/script
%% Cell type:markdown id: tags:
## Hindcast
get weekly initialized hindcasts
%% Cell type:code id: tags:
``` python
# consider renku datasets
#! renku storage pull path
```
%% Cell type:code id: tags:
``` python
```
%% Cell type:markdown id: tags:
## Observations
corresponding to hindcasts
%% Cell type:code id: tags:
``` python
# consider renku datasets
#! renku storage pull path
```
%% Cell type:code id: tags:
``` python
```
%% Cell type:markdown id: tags:
# ML model
%% Cell type:code id: tags:
``` python
import numpy as np
from tensorflow import keras

class DataGenerator(keras.utils.Sequence):
    def __init__(self, data, verif_data, batch_size=32, shuffle=True, load=True):
        """
        Data generator.
        Template from https://stanford.edu/~shervine/blog/keras-how-to-generate-data-on-the-fly

        Args:
            data: xr.DataArray with a `time` dimension (features)
            verif_data: xr.DataArray with a `time` dimension (targets)
            batch_size: samples per batch
            shuffle: shuffle sample indices after each epoch
            load: load data into RAM before training
        """
        self.data = data
        self.verif_data = verif_data
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.n_samples = data.sizes['time']
        self.on_epoch_end()
        # For some weird reason calling .load() earlier messes up the mean and std computations
        if load:
            print('Loading data into RAM')
            self.data.load()

    def __len__(self):
        'Denotes the number of batches per epoch'
        return int(np.ceil(self.n_samples / self.batch_size))

    def __getitem__(self, i):
        'Generate one batch of data'
        idxs = self.idxs[i * self.batch_size:(i + 1) * self.batch_size]
        # got all nan if nans not masked
        X = self.data.isel(time=idxs).fillna(0.).values
        y = self.verif_data.isel(time=idxs).fillna(0.).values
        return X, y

    def on_epoch_end(self):
        'Updates indexes after each epoch'
        self.idxs = np.arange(self.n_samples)
        if self.shuffle:
            np.random.shuffle(self.idxs)
```
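%% Cell type:markdown id: tags:
The batching arithmetic in `__len__`/`__getitem__` above can be checked without Keras; a minimal sketch with hypothetical sizes:
%% Cell type:code id: tags:
``` python
import numpy as np

n_samples, batch_size = 10, 3
# ceil so the final partial batch is not dropped
n_batches = int(np.ceil(n_samples / batch_size))
idxs = np.arange(n_samples)
batches = [idxs[i * batch_size:(i + 1) * batch_size] for i in range(n_batches)]
# 4 batches of sizes 3, 3, 3, 1
```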
%% Cell type:markdown id: tags:
## data prep: train, valid, test
%% Cell type:code id: tags:
``` python
# time is the forecast_reference_time
time_train_start,time_train_end='2000','2017'
time_valid_start,time_valid_end='2018','2019'
time_test = '2020'
```
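%% Cell type:markdown id: tags:
A minimal sketch (synthetic data, hypothetical variable name) of how such a year-based split could be applied to an xarray dataset along `forecast_time`:
%% Cell type:code id: tags:
``` python
import numpy as np
import pandas as pd
import xarray as xr

# weekly forecast_reference_time covering the full period (synthetic)
times = pd.date_range('2000-01-02', '2020-12-31', freq='7D')
ds = xr.Dataset({'t2m': ('forecast_time', np.random.rand(times.size))},
                coords={'forecast_time': times})

# slicing by year strings selects all forecasts in those years
train = ds.sel(forecast_time=slice('2000', '2017'))
valid = ds.sel(forecast_time=slice('2018', '2019'))
test = ds.sel(forecast_time=slice('2020', '2020'))
```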
%% Cell type:code id: tags:
``` python
dg_train = DataGenerator()
```
%% Cell type:code id: tags:
``` python
dg_valid = DataGenerator()
```
%% Cell type:code id: tags:
``` python
dg_test = DataGenerator()
```
%% Cell type:markdown id: tags:
## `fit`
%% Cell type:code id: tags:
``` python
cnn = keras.models.Sequential([])
```
%% Cell type:code id: tags:
``` python
cnn.summary()
```
%% Cell type:code id: tags:
``` python
cnn.compile(keras.optimizers.Adam(1e-4), 'mse')
```
%% Cell type:code id: tags:
``` python
import warnings
warnings.simplefilter("ignore")
```
%% Cell type:code id: tags:
``` python
cnn.fit(dg_train, epochs=1, validation_data=dg_valid)
```
%% Cell type:markdown id: tags:
## `predict`
Create predictions and print `mean(variable, lead_time, longitude, weighted latitude)` RPSS for all years as calculated by `skill_by_year`.
%% Cell type:code id: tags:
``` python
from scripts import skill_by_year
```
%% Cell type:code id: tags:
``` python
def create_predictions(model, dg):
"""Create non-iterative predictions"""
preds = model.predict(dg).squeeze()
# transform
return preds
```
%% Cell type:markdown id: tags:
### `predict` training period in-sample
%% Cell type:code id: tags:
``` python
preds_is = create_predictions(cnn, dg_train)
```
%% Cell type:code id: tags:
``` python
skill_by_year(preds_is)
```
%% Cell type:markdown id: tags:
### `predict` valid out-of-sample
%% Cell type:code id: tags:
``` python
preds_os = create_predictions(cnn, dg_valid)
```
%% Cell type:code id: tags:
``` python
skill_by_year(preds_os)
```
%% Cell type:markdown id: tags:
### `predict` test
%% Cell type:code id: tags:
``` python
preds_test = create_predictions(cnn, dg_test)
```
%% Cell type:code id: tags:
``` python
skill_by_year(preds_test)
```
%% Cell type:markdown id: tags:
# Submission
%% Cell type:code id: tags:
``` python
preds_test.sizes # expect: category(3), longitude, latitude, lead_time(2), forecast_time (53)
```
%% Cell type:code id: tags:
``` python
from scripts import assert_predictions_2020
assert_predictions_2020(preds_test)
```
%% Cell type:code id: tags:
``` python
preds_test.to_netcdf('../submissions/ML_prediction_2020.nc')
```
%% Cell type:code id: tags:
``` python
#!git add ../submissions/ML_prediction_2020.nc
#!git add ML_forecast_template.ipynb
```
%% Cell type:code id: tags:
``` python
#!git commit -m "commit submission for my_method_name" # whatever message you want
```
%% Cell type:code id: tags:
``` python
#!git tag "submission-my_method_name-0.0.1" # if this is to be checked by scorer, only the last submitted==tagged version will be considered
```
%% Cell type:code id: tags:
``` python
#!git push --tags
```
%% Cell type:code id: tags:
``` python
```
%% Cell type:markdown id: tags:
# Reproducibility
%% Cell type:markdown id: tags:
## memory
%% Cell type:code id: tags:
``` python
# https://phoenixnap.com/kb/linux-commands-check-memory-usage
!free -g
```
%% Cell type:markdown id: tags:
## CPU
%% Cell type:code id: tags:
``` python
!lscpu
```
%% Cell type:markdown id: tags:
## software
%% Cell type:code id: tags:
``` python
!conda list
```
%% Cell type:code id: tags:
``` python
```
%% Cell type:markdown id: tags:
# Train ML model to correct predictions of week 3-4 & 5-6
This notebook creates a Machine Learning `ML_model` to predict weeks 3-4 & 5-6 based on `S2S` weeks 3-4 & 5-6 forecasts; predictions are compared to `CPC` observations for the [`s2s-ai-challenge`](https://s2s-ai-challenge.github.io/).
%% Cell type:markdown id: tags:
# Synopsis
%% Cell type:markdown id: tags:
## Method: `mean bias reduction`
- calculate the mean bias from the 2000-2019 deterministic ensemble-mean hindcast
- remove that mean bias from the 2020 deterministic ensemble-mean forecast
- no Machine Learning used here
%% Cell type:markdown id: tags:
## Data used
type: renku datasets
Training-input for Machine Learning model:
- hindcasts of models:
- ECMWF: `ecmwf_hindcast-input_2000-2019_biweekly_deterministic.zarr`
Forecast-input for Machine Learning model:
- real-time 2020 forecasts of models:
- ECMWF: `ecmwf_forecast-input_2020_biweekly_deterministic.zarr`
Compare Machine Learning model forecast against ground truth:
- `CPC` observations:
- `hindcast-like-observations_biweekly_deterministic.zarr`
- `forecast-like-observations_2020_biweekly_deterministic.zarr`
%% Cell type:markdown id: tags:
## Resources used
for training; details in the Reproducibility section
- platform: MPI-M supercomputer, 1 node
- memory: 64 GB
- processors: 36 CPU
- storage required: 10 GB
%% Cell type:markdown id: tags:
## Safeguards
All points have to be checked [x]. If not, your submission is invalid.
Changes to the code after submission are not possible, as the `commit` preceding the `tag` will be reviewed.
(Only in exceptional cases, and where prior effort toward reproducibility is evident, may improvements to readability and reproducibility be allowed after November 1st 2021.)
%% Cell type:markdown id: tags:
### Safeguards to prevent [overfitting](https://en.wikipedia.org/wiki/Overfitting?wprov=sfti1)
If the organizers suspect overfitting, your contribution can be disqualified.
- [x] We did not use 2020 observations in training (explicit overfitting and cheating)
- [x] We did not repeatedly verify our model on 2020 observations nor incrementally improve our RPSS (implicit overfitting)
- [x] We provide RPSS scores for the training period with script `skill_by_year`, see in section 6.3 `predict`.
- [x] We tried our best to prevent [data leakage](https://en.wikipedia.org/wiki/Leakage_(machine_learning)?wprov=sfti1).
- [x] We honor the `train-validate-test` [split principle](https://en.wikipedia.org/wiki/Training,_validation,_and_test_sets). This means that the hindcast data is split into `train` and `validate`, whereas `test` is withheld.
- [x] We did not use `test` explicitly in training or implicitly in incrementally adjusting parameters.
- [x] We considered [cross-validation](https://en.wikipedia.org/wiki/Cross-validation_(statistics)).
%% Cell type:markdown id: tags:
### Safeguards for Reproducibility
Notebook/code must be independently reproducible from scratch by the organizers (after the competition); if this is not possible, no prize will be awarded.
- [x] All training data is publicly available (no pre-trained private neural networks, as they are not reproducible for us)
- [x] Code is well documented, readable and reproducible.
- [x] Code to reproduce training and predictions should preferably run within a day on the described architecture. If training takes longer than a day, please justify why this is needed. Please do not submit training pipelines that take weeks to train.
%% Cell type:markdown id: tags:
# Imports
%% Cell type:code id: tags:
``` python
import xarray as xr
xr.set_options(display_style='text')
import numpy as np
from dask.utils import format_bytes
import xskillscore as xs
```
%% Cell type:markdown id: tags:
# Get training data
preprocessing of input data may be done in a separate notebook/script
%% Cell type:markdown id: tags:
## Hindcast
get weekly initialized hindcasts
%% Cell type:code id: tags:
``` python
# preprocessed as renku dataset
!renku storage pull ../data/ecmwf_hindcast-input_2000-2019_biweekly_deterministic.zarr
```
%% Cell type:code id: tags:
``` python
hind_2000_2019 = xr.open_zarr("../data/ecmwf_hindcast-input_2000-2019_biweekly_deterministic.zarr", consolidated=True)
```
%% Cell type:code id: tags:
``` python
# preprocessed as renku dataset
!renku storage pull ../data/ecmwf_forecast-input_2020_biweekly_deterministic.zarr
```
%% Cell type:code id: tags:
``` python
fct_2020 = xr.open_zarr("../data/ecmwf_forecast-input_2020_biweekly_deterministic.zarr", consolidated=True)
```
%% Cell type:markdown id: tags:
## Observations
corresponding to hindcasts
%% Cell type:code id: tags:
``` python
# preprocessed as renku dataset
!renku storage pull ../data/hindcast-like-observations_2000-2019_biweekly_deterministic.zarr
```
%% Cell type:code id: tags:
``` python
obs_2000_2019 = xr.open_zarr("../data/hindcast-like-observations_2000-2019_biweekly_deterministic.zarr", consolidated=True)
```
%% Cell type:code id: tags:
``` python
# preprocessed as renku dataset
!renku storage pull ../data/forecast-like-observations_2020_biweekly_deterministic.zarr
```
%% Cell type:code id: tags:
``` python
obs_2020 = xr.open_zarr("../data/forecast-like-observations_2020_biweekly_deterministic.zarr", consolidated=True)
```
%% Cell type:markdown id: tags:
# no ML model
%% Cell type:markdown id: tags:
Here, we just remove the mean bias from the ensemble mean forecast.
%% Cell type:code id: tags:
``` python
bias_2000_2019 = (hind_2000_2019.mean('realization') - obs_2000_2019).groupby('forecast_time.weekofyear').mean().compute()
```
%% Output
/work/mh0727/m300524/conda-envs/s2s-ai/lib/python3.7/site-packages/xarray/core/accessor_dt.py:381: FutureWarning: dt.weekofyear and dt.week have been deprecated. Please use dt.isocalendar().week instead.
FutureWarning,
/work/mh0727/m300524/conda-envs/s2s-ai/lib/python3.7/site-packages/dask/array/numpy_compat.py:40: RuntimeWarning: invalid value encountered in true_divide
x = np.divide(x1, x2, out)
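%% Cell type:markdown id: tags:
The groupwise bias removal above can be illustrated on synthetic data (hypothetical values; `forecast_time.month` is used here instead of the deprecated `weekofyear` grouping, but the pattern is identical):
%% Cell type:code id: tags:
``` python
import numpy as np
import pandas as pd
import xarray as xr

# synthetic example: hindcast is uniformly 0.5 too warm
times = pd.date_range('2000-01-02', periods=104, freq='7D')
hind = xr.DataArray(np.full(104, 2.0), dims='forecast_time',
                    coords={'forecast_time': times}, name='t2m')
obs = xr.DataArray(np.full(104, 1.5), dims='forecast_time',
                   coords={'forecast_time': times}, name='t2m')

# mean bias per calendar group
bias = (hind - obs).groupby('forecast_time.month').mean()
# subtract each forecast's matching group bias
debiased = hind.groupby('forecast_time.month') - bias
```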
%% Cell type:markdown id: tags:
## `predict`
Create predictions and print `mean(variable, lead_time, longitude, weighted latitude)` RPSS for all years as calculated by `skill_by_year`.
%% Cell type:code id: tags:
``` python
from scripts import make_probabilistic
```
%% Cell type:code id: tags:
``` python
!renku storage pull ../data/hindcast-like-observations_2000-2019_biweekly_tercile-edges.nc
```
%% Cell type:code id: tags:
``` python
cache_path='../data'
tercile_file = f'{cache_path}/hindcast-like-observations_2000-2019_biweekly_tercile-edges.nc'
tercile_edges = xr.open_dataset(tercile_file)
```
%% Cell type:code id: tags:
``` python
# this is not useful but results have expected dimensions
# actually train for each lead_time
def create_predictions(fct, bias):
preds = fct - bias.sel(weekofyear=fct.forecast_time.dt.weekofyear)
preds = make_probabilistic(preds, tercile_edges)
return preds
```
%% Cell type:markdown id: tags:
### `predict` training period in-sample
%% Cell type:code id: tags:
``` python
!renku storage pull ../data/forecast-like-observations_2020_biweekly_terciled.nc
```
%% Cell type:code id: tags:
``` python
!renku storage pull ../data/hindcast-like-observations_2000-2019_biweekly_terciled.zarr
```
%% Cell type:code id: tags:
``` python
from scripts import skill_by_year
```
%% Cell type:code id: tags:
``` python
preds_is = create_predictions(hind_2000_2019, bias_2000_2019).compute()
```
%% Output
/work/mh0727/m300524/conda-envs/s2s-ai/lib/python3.7/site-packages/xarray/core/accessor_dt.py:381: FutureWarning: dt.weekofyear and dt.week have been deprecated. Please use dt.isocalendar().week instead.
FutureWarning,
/work/mh0727/m300524/conda-envs/s2s-ai/lib/python3.7/site-packages/xarray/core/accessor_dt.py:381: FutureWarning: dt.weekofyear and dt.week have been deprecated. Please use dt.isocalendar().week instead.
FutureWarning,
%% Cell type:code id: tags:
``` python
skill_by_year(preds_is)
```
%% Output
RPSS
year
2000 0.072817
2001 0.002777
2002 -0.001713
2003 -0.008863
2004 -0.075933
2005 0.040779
2006 0.020809
2007 -0.079551
2008 0.025663
2009 -0.007614
2010 0.048613
2011 -0.058636
2012 -0.079176
2013 0.005424
2014 0.002710
2015 0.028784
2016 0.051436
2017 -0.090526
2018 -0.121472
2019 0.004270
%% Cell type:markdown id: tags:
### `predict` test
%% Cell type:code id: tags:
``` python
preds_test = create_predictions(fct_2020, bias_2000_2019)
```
%% Output
/work/mh0727/m300524/conda-envs/s2s-ai/lib/python3.7/site-packages/xarray/core/accessor_dt.py:381: FutureWarning: dt.weekofyear and dt.week have been deprecated. Please use dt.isocalendar().week instead.
FutureWarning,
/work/mh0727/m300524/conda-envs/s2s-ai/lib/python3.7/site-packages/xarray/core/accessor_dt.py:381: FutureWarning: dt.weekofyear and dt.week have been deprecated. Please use dt.isocalendar().week instead.
FutureWarning,
%% Cell type:code id: tags:
``` python
skill_by_year(preds_test)
```
%% Output
RPSS
year
2020 0.037885
%% Cell type:markdown id: tags:
# Submission
%% Cell type:code id: tags:
``` python
from scripts import assert_predictions_2020
assert_predictions_2020(preds_test)
```
%% Cell type:code id: tags:
``` python
del preds_test['weekofyear']
```
%% Cell type:code id: tags:
``` python
preds_test.to_netcdf('../submissions/ML_prediction_2020.nc')
```
%% Cell type:code id: tags:
``` python
#!git add ../submissions/ML_prediction_2020.nc
#!git add mean_bias_reduction.ipynb
```
%% Cell type:code id: tags:
``` python
#!git commit -m "template_test no ML mean bias reduction" # whatever message you want
```
%% Cell type:code id: tags:
``` python
#!git tag "submission-no_ML_mean_bias_reduction-0.0.1" # if this is to be checked by scorer, only the last submitted==tagged version will be considered
```
%% Cell type:code id: tags:
``` python
#!git push --tags
```
%% Cell type:code id: tags:
``` python
```
%% Cell type:markdown id: tags:
# Reproducibility
%% Cell type:markdown id: tags:
## memory
%% Cell type:code id: tags:
``` python
# https://phoenixnap.com/kb/linux-commands-check-memory-usage
!free -g
```
%% Output
total used free shared buffers cached
Mem: 62 20 42 0 0 6
-/+ buffers/cache: 13 48
Swap: 0 0 0
%% Cell type:markdown id: tags:
## CPU
%% Cell type:code id: tags:
``` python
!lscpu
```
%% Output
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Byte Order: Little Endian
CPU(s): 48
On-line CPU(s) list: 0-47
Thread(s) per core: 2
Core(s) per socket: 12
Socket(s): 2
NUMA node(s): 2
Vendor ID: GenuineIntel
CPU family: 6
Model: 63
Model name: Intel(R) Xeon(R) CPU E5-2680 v3 @ 2.50GHz
Stepping: 2
CPU MHz: 1200.000
BogoMIPS: 4988.09
Virtualization: VT-x
L1d cache: 32K
L1i cache: 32K
L2 cache: 256K
L3 cache: 30720K
NUMA node0 CPU(s): 0-11,24-35
NUMA node1 CPU(s): 12-23,36-47
%% Cell type:markdown id: tags:
## software
%% Cell type:code id: tags:
``` python
!conda list
```
%% Cell type:code id: tags:
``` python
```
@@ -74,6 +74,11 @@ def skill_by_year(preds):
    # ML probabilities
    fct_p = preds
    # check inputs
    assert_predictions_2020(obs_p)
    assert_predictions_2020(fct_p)
    # climatology
    clim_p = xr.DataArray([1/3, 1/3, 1/3], dims='category', coords={'category': ['below normal', 'near normal', 'above normal']}).to_dataset(name='tp')
    clim_p['t2m'] = clim_p['tp']
@@ -82,21 +87,22 @@ def skill_by_year(preds):
    # rps_ML
    rps_ML = xs.rps(obs_p, fct_p, category_edges=None, dim=[], input_distributions='p').compute()
    # rps_clim
    rps_clim = xs.rps(obs_p, clim_p, category_edges=None, dim=[], input_distributions='p').compute()
    # rpss
    rpss = 1 - (rps_ML / rps_clim)
    # cleaning
    # check for -inf grid cells
    if (rpss == -np.inf).to_array().any():
        print(f'find N={(rpss == rpss.min()).sum()} -inf grid cells')
    # https://renkulab.io/gitlab/aaron.spring/s2s-ai-challenge-template/-/issues/7
    # penalize
    penalize = obs_p.where(fct_p != 1, other=-10).mean('category')
    rpss = rpss.where(penalize != 0, other=-10)
    # what to do with requested grid cells where NaN is submitted? also penalize, todo: https://renkulab.io/gitlab/aaron.spring/s2s-ai-challenge-template/-/issues/7
    # clip
    rpss = rpss.clip(-10, 1)
    # average over all forecasts
    rpss = rpss.groupby('forecast_time.year').mean()
    # weighted area mean
    weights = np.cos(np.deg2rad(np.abs(rpss.latitude)))
@@ -106,47 +112,62 @@ def skill_by_year(preds):
    return scores.to_dataframe('RPSS')

def assert_predictions_2020(preds_test, exclude='weekofyear'):
    """Check the variables, coordinates and dimensions of 2020 predictions."""
    from xarray.testing import assert_equal  # doesn't care about attrs but checks coords
    # is dataset
    assert isinstance(preds_test, xr.Dataset)
    # has both vars: tp and t2m
    if 'data_vars' in exclude:
        assert 'tp' in preds_test.data_vars
        assert 't2m' in preds_test.data_vars
    ## coords
    # ignore weekofyear coord if not dim
    if 'weekofyear' in exclude and 'weekofyear' in preds_test.coords and 'weekofyear' not in preds_test.dims:
        preds_test = preds_test.drop('weekofyear')
    # forecast_time
    if 'forecast_time' in exclude:
        d = pd.date_range(start='2020-01-02', freq='7D', periods=53)
        forecast_time = xr.DataArray(d, dims='forecast_time', coords={'forecast_time': d}, name='forecast_time')
        assert_equal(forecast_time, preds_test['forecast_time'])
    # longitude
    if 'longitude' in exclude:
        lon = np.arange(0., 360., 1.5)
        longitude = xr.DataArray(lon, dims='longitude', coords={'longitude': lon}, name='longitude')
        assert_equal(longitude, preds_test['longitude'])
    # latitude
    if 'latitude' in exclude:
        lat = np.arange(-90., 90.1, 1.5)[::-1]
        latitude = xr.DataArray(lat, dims='latitude', coords={'latitude': lat}, name='latitude')
        assert_equal(latitude, preds_test['latitude'])
    # lead_time
    if 'lead_time' in exclude:
        lead = [pd.Timedelta(f'{i} d') for i in [14, 28]]
        lead_time = xr.DataArray(lead, dims='lead_time', coords={'lead_time': lead}, name='lead_time')
        assert_equal(lead_time, preds_test['lead_time'])
    # category
    if 'category' in exclude:
        cat = np.array(['below normal', 'near normal', 'above normal'], dtype='<U12')
        category = xr.DataArray(cat, dims='category', coords={'category': cat}, name='category')
        assert_equal(category, preds_test['category'])
    # size
    if 'size' in exclude:
        from dask.utils import format_bytes
        size_in_MB = float(format_bytes(preds_test.nbytes).split(' ')[0])
        # todo: refine for dtypes
        assert size_in_MB > 50
        assert size_in_MB < 250
    # no other dims
    if 'dims' in exclude:
        assert set(preds_test.dims) - {'category', 'forecast_time', 'latitude', 'lead_time', 'longitude'} == set()