# scoring_script.py (4.7 KiB) — extracted from a GitLab blame view.
# blame: Tasko Olevski
import xarray as xr
import xskillscore as xs
import numpy as np
import pandas as pd
# blame: Aaron Spring

# blame: Tasko Olevski
import argparse
from pathlib import Path

def assert_predictions_2020(preds_test, exclude='week'):
    """Check the variables, coordinates and dimensions of 2020 predictions.

    Parameters
    ----------
    preds_test : xr.Dataset
        Dataset to validate (a submission or the terciled observations).
    exclude : str or iterable of str or None, default 'week'
        Names of checks to skip: ``'data_vars'``, ``'forecast_time'``,
        ``'longitude'``, ``'latitude'``, ``'lead_time'``, ``'category'``,
        ``'size'`` and ``'dims'``. Listing ``'week'`` drops a ``week``
        coordinate that is not a dimension before the coordinate checks.
        ``None`` is treated as "skip nothing".

    Raises
    ------
    AssertionError
        If any check that is not excluded fails.
    """
    from xarray.testing import assert_equal  # ignores attrs but checks coords

    if exclude is None:
        exclude = []
    elif isinstance(exclude, str):
        exclude = [exclude]

    # is dataset
    assert isinstance(preds_test, xr.Dataset)

    # has both vars: tp and t2m
    if 'data_vars' not in exclude:
        assert 'tp' in preds_test.data_vars
        assert 't2m' in preds_test.data_vars

    ## coords
    # ignore week coord if not dim
    if 'week' in exclude and 'week' in preds_test.coords and 'week' not in preds_test.dims:
        # drop_vars is the non-deprecated replacement for Dataset.drop
        preds_test = preds_test.drop_vars('week')

    # forecast_time: 53 weekly dates starting 2020-01-02
    if 'forecast_time' not in exclude:
        d = pd.date_range(start='2020-01-02', freq='7D', periods=53)
        forecast_time = xr.DataArray(d, dims='forecast_time', coords={'forecast_time': d}, name='forecast_time')
        assert_equal(forecast_time, preds_test['forecast_time'])

    # longitude: global 1.5 degree grid, 0 to 358.5
    if 'longitude' not in exclude:
        lon = np.arange(0., 360., 1.5)
        longitude = xr.DataArray(lon, dims='longitude', coords={'longitude': lon}, name='longitude')
        assert_equal(longitude, preds_test['longitude'])

    # latitude: the full global grid is not required (submissions may omit
    # the region south of 60S); only require descending 1.5 degree spacing
    # that covers both 60N and 60S.
    if 'latitude' not in exclude:
        assert (preds_test.latitude.diff('latitude') == -1.5).all()
        assert 60 in preds_test.latitude
        assert -60 in preds_test.latitude

    # lead_time: 14 and 28 days
    if 'lead_time' not in exclude:
        lead = [pd.Timedelta(f'{i} d') for i in [14, 28]]
        lead_time = xr.DataArray(lead, dims='lead_time', coords={'lead_time': lead}, name='lead_time')
        assert_equal(lead_time, preds_test['lead_time'])

    # category: the three terciles, in this order
    if 'category' not in exclude:
        cat = np.array(['below normal', 'near normal', 'above normal'], dtype='<U12')
        category = xr.DataArray(cat, dims='category', coords={'category': cat}, name='category')
        assert_equal(category, preds_test['category'])

    # size: compare raw bytes. Parsing the number out of a human-readable
    # "1.23 GiB"-style string would silently drop its unit.
    if 'size' not in exclude:
        size_in_MiB = preds_test.nbytes / 2**20
        # todo: refine for dtypes
        assert size_in_MiB < 250

    # no other dims
    if 'dims' not in exclude:
        allowed_dims = {'category', 'forecast_time', 'latitude', 'lead_time', 'longitude'}
        assert set(preds_test.dims) <= allowed_dims

# blame: Tasko Olevski
if __name__ == "__main__":
    # Score one submission: print its latitude-weighted RPSS against a
    # climatological (1/3, 1/3, 1/3) forecast, averaged over lead times and
    # variables. This single value is what goes to the leaderboard.
    parser = argparse.ArgumentParser()
    parser.add_argument("prediction", help="The netcdf file with predictions")
    args = parser.parse_args()

    cache_path = "scoring"

    observations_terciled_fin = Path(f'{cache_path}/forecast-like-observations_2020_biweekly_terciled.nc')

    obs_p = xr.open_dataset(observations_terciled_fin)

    fct_p = xr.open_dataset(args.prediction)

    # check inputs
    assert_predictions_2020(obs_p)
    assert_predictions_2020(fct_p)

    # climatology forecast: equal tercile probabilities for both variables
    clim_p = xr.DataArray([1/3, 1/3, 1/3], dims='category', coords={'category':['below normal', 'near normal', 'above normal']}).to_dataset(name='tp')
    clim_p['t2m'] = clim_p['tp']

    # submission rps
    rps_ML = xs.rps(obs_p, fct_p, category_edges=None, dim=[], input_distributions='p').compute()
    # climatology forecast rps
    rps_clim = xs.rps(obs_p, clim_p, category_edges=None, dim=[], input_distributions='p').compute()

    ## RPSS
    # penalize # https://renkulab.io/gitlab/aaron.spring/s2s-ai-challenge-template/-/issues/7
    # A forecast is expected wherever the terciled observations are valid,
    # i.e. their category probabilities sum to ~1 (all-NaN observations sum to 0).
    obs_sum = obs_p.sum('category')
    expected = (obs_sum > 0.98) & (obs_sum < 1.02)

    # https://renkulab.io/gitlab/aaron.spring/s2s-ai-challenge-template/-/issues/50
    # Assign RPS=2 where a value was expected but NaN was found, so missing
    # forecasts are penalised instead of being skipped by the NaN-skipping
    # mean below. The condition must be boolean: with a float ~1.0/NaN mask,
    # NaN counts as True and the penalty never applies.
    rps_ML = rps_ML.where(rps_ML.notnull() | ~expected, other=2)

    # following Weigel 2007: https://doi.org/10.1175/MWR3280.1
    rpss = 1 - (rps_ML.mean('forecast_time') / rps_clim.mean('forecast_time'))
    # clip
    rpss = rpss.clip(-10, 1)

    # weighted area mean
    weights = np.cos(np.deg2rad(np.abs(rpss.latitude)))
    # spatially weighted score averaged over lead_times and variables to one single value;
    # latitude is descending, so slice(None, -60) keeps everything down to 60S
    scores = rpss.sel(latitude=slice(None, -60)).weighted(weights).mean('latitude').mean('longitude')
    scores = scores.to_array().mean(['lead_time', 'variable']).reset_coords(drop=True)
    # score transferred to leaderboard
    print(scores.item())