import xarray as xr
import xskillscore as xs
import numpy as np
import pandas as pd

import argparse
from pathlib import Path
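
# Scoring script for the S2S AI Challenge (see https://renkulab.io/gitlab/aaron.spring/s2s-ai-challenge-template):
# checks that a submitted netcdf file of 2020 tercile-probability predictions has the expected structure,
# computes the RPSS against a 1/3-1/3-1/3 climatological forecast, and prints a single scalar score.
# usage: python <this script> <prediction netcdf file>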

def assert_predictions_2020(preds_test, exclude='week'):
    """Check the variables, coordinates and dimensions of 2020 predictions."""
    from xarray.testing import assert_equal  # doesn't compare attrs, but does compare coords
    if isinstance(exclude, str):
        exclude = [exclude]

    # is dataset
    assert isinstance(preds_test, xr.Dataset)

    # has both vars: tp and t2m
    if 'data_vars' not in exclude:
        assert 'tp' in preds_test.data_vars
        assert 't2m' in preds_test.data_vars

    ## coords
    # ignore week coord if not dim
    if 'week' in exclude and 'week' in preds_test.coords and 'week' not in preds_test.dims:
        preds_test = preds_test.drop_vars('week')

    # forecast_time
    if 'forecast_time' not in exclude:
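        # 53 weekly initializations, every Thursday from 2020-01-02 through 2020-12-31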
        d = pd.date_range(start='2020-01-02', freq='7D', periods=53)
        forecast_time = xr.DataArray(d, dims='forecast_time', coords={'forecast_time':d}, name='forecast_time')
        assert_equal(forecast_time, preds_test['forecast_time'])

    # longitude
    if 'longitude' not in exclude:
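        # global 1.5 degree grid: 0.0, 1.5, ..., 358.5 (240 points)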
        lon = np.arange(0., 360., 1.5)
        longitude = xr.DataArray(lon, dims='longitude', coords={'longitude': lon}, name='longitude')
        assert_equal(longitude, preds_test['longitude'])

    # latitude
    if 'latitude' not in exclude:
        # checking against the full grid (90N to 90S, i.e. np.arange(-90., 90.1, 1.5)[::-1])
        # would be too strict: submissions may omit latitudes south of 60S.
        # instead require a descending 1.5 degree grid that covers at least 60N to 60S.
        assert (preds_test.latitude.diff('latitude')==-1.5).all()
        assert 60 in preds_test.latitude
        assert -60 in preds_test.latitude

    # lead_time
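    # two biweekly lead times: 14 and 28 days after forecast_time (weeks 3-4 and 5-6)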
    if 'lead_time' not in exclude:
        lead = [pd.Timedelta(f'{i} d') for i in [14, 28]]
        lead_time = xr.DataArray(lead, dims='lead_time', coords={'lead_time': lead}, name='lead_time')
        assert_equal(lead_time, preds_test['lead_time'])

    # category
    if 'category' not in exclude:
        cat = np.array(['below normal', 'near normal', 'above normal'], dtype='<U12')
        category = xr.DataArray(cat, dims='category', coords={'category': cat}, name='category')
        assert_equal(category, preds_test['category'])

    # size
    if 'size' not in exclude:
        # compute the size in MiB directly from nbytes; parsing dask.utils.format_bytes
        # would misread submissions whose size formats as kiB or GiB
        size_in_MiB = preds_test.nbytes / 2**20
        # todo: refine for dtypes
        assert size_in_MiB > 30
        assert size_in_MiB < 250

    # no other dims
    if 'dims' not in exclude:
        assert set(preds_test.dims) - {'category', 'forecast_time', 'latitude', 'lead_time', 'longitude'} == set()

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("prediction", help="The netcdf file with predictions")
    args = parser.parse_args()

    cache_path = "scoring"

    observations_terciled_fin = Path(f'{cache_path}/forecast-like-observations_2020_biweekly_terciled.nc')

    obs_p = xr.open_dataset(observations_terciled_fin)

    fct_p = xr.open_dataset(args.prediction)

    # check inputs
    assert_predictions_2020(obs_p)
    assert_predictions_2020(fct_p)

    # climatology forecast
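    # equal 1/3 probability for each tercile category and both variables; the missing
    # forecast_time/latitude/longitude dimensions are broadcast against the observations inside xs.rps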
    clim_p = xr.DataArray([1/3, 1/3, 1/3], dims='category', coords={'category':['below normal', 'near normal', 'above normal']}).to_dataset(name='tp')
    clim_p['t2m'] = clim_p['tp']

    # submission rps
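    # category_edges=None with input_distributions='p' treats the inputs as ready-made
    # category probabilities; dim=[] keeps all dimensions (no averaging inside rps)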
    rps_ML = xs.rps(obs_p, fct_p, category_edges=None, dim=[], input_distributions='p').compute()
    # climatology forecast rps
    rps_clim = xs.rps(obs_p, clim_p, category_edges=None, dim=[], input_distributions='p').compute()

    ## RPSS
    # penalize missing forecasts, see https://renkulab.io/gitlab/aaron.spring/s2s-ai-challenge-template/-/issues/7
    # expect is ~1 where terciled observations exist and NaN elsewhere
    expect = obs_p.sum('category')
    expect = expect.where(expect > 0.98).where(expect < 1.02)

    # https://renkulab.io/gitlab/aaron.spring/s2s-ai-challenge-template/-/issues/50
    # assign RPS=2 where a forecast was expected (observations exist) but the submission yields NaN
    rps_ML = rps_ML.where(expect.isnull() | rps_ML.notnull(), other=2)

    # following Weigel 2007: https://doi.org/10.1175/MWR3280.1
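    # RPSS = 1 - RPS_ML / RPS_clim, each averaged over forecast_time first; positive values beat climatology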
    rpss = 1 - (rps_ML.mean('forecast_time') / rps_clim.mean('forecast_time'))
    # clip to [-10, 1] so that a few extremely negative grid points cannot dominate the spatial mean
    rpss = rpss.clip(-10, 1)

    # cosine-latitude weights for the area-weighted spatial mean
    weights = np.cos(np.deg2rad(np.abs(rpss.latitude)))
    # spatially weighted score averaged over lead_times and variables to one single value
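    # restrict to 90N-60S: latitudes south of 60S are not scored (cf. the latitude check above)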
    scores = rpss.sel(latitude=slice(None, -60)).weighted(weights).mean('latitude').mean('longitude')
    scores = scores.to_array().mean(['lead_time', 'variable']).reset_coords(drop=True)
    # score transferred to the leaderboard
    print(scores.item())