"""Validate 2020 tercile-probability predictions and print their RPSS-based score."""
import argparse
from pathlib import Path

import numpy as np
import pandas as pd
import xarray as xr
import xskillscore as xs

# Only these dimensions may be present in a prediction dataset.
_ALLOWED_DIMS = frozenset(
    {'category', 'forecast_time', 'latitude', 'lead_time', 'longitude'}
)


def _require(condition, message):
    """Raise ``AssertionError(message)`` unless ``condition`` is truthy.

    Used instead of a bare ``assert`` so the validation is not stripped when
    Python runs with ``-O``; the exception type stays ``AssertionError``.
    """
    if not condition:
        raise AssertionError(message)


def assert_predictions_2020(preds_test, exclude='week'):
    """Check the variables, coordinates and dimensions of 2020 predictions.

    Args:
        preds_test: the ``xr.Dataset`` to validate.
        exclude: name, or iterable of names, of checks to skip. Recognised
            names: ``'data_vars'``, ``'week'``, ``'forecast_time'``,
            ``'longitude'``, ``'latitude'``, ``'lead_time'``, ``'category'``,
            ``'size'`` and ``'dims'``. ``None`` skips nothing.

    Raises:
        AssertionError: if any check that is not excluded fails.
    """
    from xarray.testing import assert_equal  # doesn't care about attrs but checks coords

    if exclude is None:
        exclude = []
    elif isinstance(exclude, str):
        exclude = [exclude]

    # is dataset
    _require(isinstance(preds_test, xr.Dataset), 'predictions must be an xr.Dataset')

    # has both vars: tp and t2m
    if 'data_vars' not in exclude:
        for var in ('tp', 't2m'):
            _require(var in preds_test.data_vars, f'missing data variable {var!r}')

    ## coords
    # ignore week coord if not dim
    if 'week' in exclude and 'week' in preds_test.coords and 'week' not in preds_test.dims:
        # drop_vars replaces the deprecated Dataset.drop
        preds_test = preds_test.drop_vars('week')

    # forecast_time: weekly starts, 53 of them, beginning 2020-01-02
    if 'forecast_time' not in exclude:
        d = pd.date_range(start='2020-01-02', freq='7D', periods=53)
        forecast_time = xr.DataArray(
            d, dims='forecast_time', coords={'forecast_time': d}, name='forecast_time'
        )
        assert_equal(forecast_time, preds_test['forecast_time'])

    # longitude: 0 .. 358.5 in 1.5 degree steps
    if 'longitude' not in exclude:
        lon = np.arange(0., 360., 1.5)
        longitude = xr.DataArray(
            lon, dims='longitude', coords={'longitude': lon}, name='longitude'
        )
        assert_equal(longitude, preds_test['longitude'])

    # latitude: an exact comparison against 90 .. -90 would be too strict,
    # allow for not submitting south of 60S
    if 'latitude' not in exclude:
        lat_steps = preds_test.latitude.diff('latitude')
        _require(
            bool((lat_steps == -1.5).all()),
            'latitude must decrease in steps of 1.5',
        )
        _require(60 in preds_test.latitude, 'latitude must contain 60')
        _require(-60 in preds_test.latitude, 'latitude must contain -60')

    # lead_time: 14 and 28 days
    if 'lead_time' not in exclude:
        lead = [pd.Timedelta(f'{i} d') for i in [14, 28]]
        lead_time = xr.DataArray(
            lead, dims='lead_time', coords={'lead_time': lead}, name='lead_time'
        )
        assert_equal(lead_time, preds_test['lead_time'])

    # category
    if 'category' not in exclude:
        cat = np.array(['below normal', 'near normal', 'above normal'], dtype='<U12')
        category = xr.DataArray(
            cat, dims='category', coords={'category': cat}, name='category'
        )
        assert_equal(category, preds_test['category'])

    # size
    if 'size' not in exclude:
        # Computed directly from nbytes: parsing the leading number out of a
        # human-readable string ignores the unit (kiB / MiB / GiB), so sizes in
        # different units would be compared against the same bounds.
        size_in_mib = preds_test.nbytes / 2 ** 20
        # todo: refine for dtypes
        _require(size_in_mib > 30, f'dataset too small: {size_in_mib:.2f} MiB')
        _require(size_in_mib < 250, f'dataset too large: {size_in_mib:.2f} MiB')

    # no other dims (guard was inverted: the check only ran when 'dims' was excluded)
    if 'dims' not in exclude:
        extra_dims = set(preds_test.dims) - _ALLOWED_DIMS
        _require(not extra_dims, f'unexpected dimensions: {sorted(extra_dims)}')


def main():
    """Score the prediction file given on the command line and print the score."""
    parser = argparse.ArgumentParser()
    parser.add_argument("prediction", help="The netcdf file with predictions")
    args = parser.parse_args()

    cache_path = "scoring"
    observations_terciled_fin = Path(
        f'{cache_path}/forecast-like-observations_2020_biweekly_terciled.nc'
    )

    # Context managers make sure both files are closed again.
    with xr.open_dataset(observations_terciled_fin) as obs_p, \
            xr.open_dataset(args.prediction) as fct_p:
        # check inputs
        assert_predictions_2020(obs_p)
        assert_predictions_2020(fct_p)

        # climatology forecast: 1/3 probability for each category
        clim_p = xr.DataArray(
            [1/3, 1/3, 1/3],
            dims='category',
            coords={'category': ['below normal', 'near normal', 'above normal']},
        ).to_dataset(name='tp')
        clim_p['t2m'] = clim_p['tp']

        # submission rps
        rps_ML = xs.rps(
            obs_p, fct_p, category_edges=None, dim=[], input_distributions='p'
        ).compute()
        # climatology forecast rps
        rps_clim = xs.rps(
            obs_p, clim_p, category_edges=None, dim=[], input_distributions='p'
        ).compute()

        ## RPSS
        # penalize
        # https://renkulab.io/gitlab/aaron.spring/s2s-ai-challenge-template/-/issues/7
        # A value is expected wherever the observed category probabilities sum to ~1.
        obs_sum = obs_p.sum('category')
        expected = (obs_sum > 0.98) & (obs_sum < 1.02)

        # https://renkulab.io/gitlab/aaron.spring/s2s-ai-challenge-template/-/issues/50
        # Assign RPS=2 where a value was expected but NaN found. The mask must be
        # boolean: a float condition holding NaN counts as True everywhere, which
        # silently disables the penalty.
        missing = rps_ML.isnull() & expected
        rps_ML = rps_ML.where(~missing, other=2)

        # following Weigel 2007: https://doi.org/10.1175/MWR3280.1
        rpss = 1 - (rps_ML.mean('forecast_time') / rps_clim.mean('forecast_time'))
        # clip
        rpss = rpss.clip(-10, 1)

        # weighted area mean
        weights = np.cos(np.deg2rad(np.abs(rpss.latitude)))
        # spatially weighted score averaged over lead_times and variables
        # to one single value
        scores = (
            rpss.sel(latitude=slice(None, -60))
            .weighted(weights)
            .mean('latitude')
            .mean('longitude')
        )
        scores = scores.to_array().mean(['lead_time', 'variable']).reset_coords(drop=True)
        score = scores.item()

    # score transferred to leaderboard
    print(score)


if __name__ == "__main__":
    main()