{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pathlib import Path\n",
    "\n",
    "import pandas as pd\n",
    "import altair as alt\n",
    "from IPython.display import display, HTML\n",
    "\n",
    "from covid_19_utils.converters import CaseConverter"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "html_credits=HTML('''\n",
    "<p style=\"font-size: smaller\">Data Sources: \n",
    "  <a href=\"https://covidtracking.com\">The COVID Tracking Project</a>\n",
    "<br>\n",
    "Analysis and Visualization:\n",
    "  <a href=\"https://renkulab.io/projects/covid-19/covid-19-public-data\">Covid-19 Public Data Collaboration Project</a>\n",
    "</p>''')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": [
     "parameters"
    ]
   },
   "outputs": [],
   "source": [
    "data_path = '../data/covidtracking'\n",
    "atlas_path = '../data/atlas'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# read in the data \n",
    "converter = CaseConverter(atlas_path)\n",
    "data_df = converter.read_convert(data_path)\n",
    "\n",
    "# referring to \"state\" will make more sense in this notebook\n",
    "data_df = data_df.rename(columns={\"region_label\": \"state\"})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Compute daily differences\n",
    "tdf = data_df.sort_values(['state', 'date'], ascending=[True, False]).set_index(['state', 'date'])\n",
    "diffs_df = tdf[['positive', 'deceased', 'positive_100k', 'deceased_100k']].groupby(level='state').diff(periods=-1).dropna(how='all')\n",
    "tdf_diff=tdf.join(diffs_df, rsuffix='_diff').reset_index()\n",
    "\n",
    "# \"Normalizing\" the total tests\n",
    "tdf_diff['total_10'] = tdf_diff['tested']/10.\n",
    "\n",
    "# Daily totals\n",
    "daily_totals = tdf_diff.groupby('date').sum()\n",
    "daily_totals.reset_index(level=0, inplace=True)\n",
    "\n",
    "# National daily totals\n",
    "nation_df = data_df.groupby('date').sum()\n",
    "nation_df['state']='All US'\n",
    "nation_df = nation_df.reset_index()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Covid-19 Cases in U.S.\n",
    "\n",
    "The case data from the U.S. is obtained from https://covidtracking.com, a public crowd-sourced covid-19 dataset. "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Growth trends"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# make dataframe for text labels on chart - hand edit these label locations\n",
    "textLabels_df = pd.DataFrame(\n",
    "    [[10,6000,'doubles every day'],\n",
    "     [36,50000,'doubles every 3 days'],\n",
    "     [34,100, 'doubles every week']],\n",
    "    columns =['labelX', 'labelY','labelText']\n",
    ")\n",
    "\n",
    "startCase = 2000\n",
    "\n",
    "# make dataframe of states with points >=10 deceaseds\n",
    "deceased10_df = data_df.loc[data_df['deceased']>=startCase]\n",
    "\n",
    "# group deceased10 dataframe by state and then increasing order of date\n",
    "deceased10_df = deceased10_df.sort_values(by=['state','date'])\n",
    "\n",
    "# add US to that dataframe\n",
    "nationdeceased10_df = nation_df.loc[nation_df['deceased']>=startCase]\n",
    "deceased10_df= pd.concat ([deceased10_df,nationdeceased10_df])\n",
    "\n",
    "deceased10_df = deceased10_df.reset_index()\n",
    "\n",
    "# make a list of the states with 10 or more deceaseds\n",
    "state_list = list(set(deceased10_df['state']))\n",
    "\n",
    "# add a column for the number of days since the 10th deceased for each state\n",
    "for state, df in deceased10_df.groupby('state'):\n",
    "    deceased10_df.loc[df.index,'sinceDay0'] = range(0, len(df))\n",
    "deceased10_df = deceased10_df.astype({'sinceDay0': 'int32'})\n",
    "\n",
    "#Now create plotlines for each state since 10 deceaseds\n",
    "lineChart = alt.Chart(deceased10_df,title=f'US States: Cumulative Deaths Since {startCase}th Death').mark_line(interpolate='basis').encode(\n",
    "    alt.X('sinceDay0:Q', axis=alt.Axis(title=f'Days Since {startCase}th Death')),\n",
    "    alt.Y('deceased:Q',\n",
    "         axis = alt.Axis(title='Cumulative Deaths'),\n",
    "         scale=alt.Scale(type='log')),\n",
    "    tooltip=['state', 'sinceDay0', 'deceased', 'positive'],\n",
    "    color = 'state'\n",
    ").properties(width=800,height=400)\n",
    "\n",
    "## Create a layer with the lines for doubling every day and doubling every week\n",
    "\n",
    "# Compute theoretical trends of doubling every day, 3 days, week\n",
    "days = {'day':[1,2,3,4,5,10,15,20, max(deceased10_df.sinceDay0)+5]}\n",
    "logRuleDay_df = pd.DataFrame(days, columns=['day'])\n",
    "logRuleDay_df['case']= startCase * pow(2,logRuleDay_df['day'])\n",
    "logRuleDay_df['doubling period']='every day'\n",
    "\n",
    "logRule3Days_df = pd.DataFrame(days, columns=['day'])\n",
    "logRule3Days_df['case']= startCase * pow(2,(logRule3Days_df['day'])/3)\n",
    "logRule3Days_df['doubling period']='three days'\n",
    "\n",
    "logRuleWeek_df = pd.DataFrame(days, columns=['day'])\n",
    "logRuleWeek_df['case']= startCase * pow(2,(logRuleWeek_df['day'])/7)\n",
    "logRuleWeek_df['doubling period']='every week'\n",
    "\n",
    "logRules_df = pd.concat([logRuleDay_df, logRule3Days_df, logRuleWeek_df])\n",
    "logRules_df = logRules_df.reset_index()\n",
    "\n",
    "\n",
    "ruleChart = alt.Chart(logRules_df).mark_line(opacity=0.2,clip=True).encode(\n",
    "    alt.X('day:Q',\n",
    "            scale=alt.Scale(domain=[1,max(deceased10_df.sinceDay0)+5])),\n",
    "    alt.Y('case', scale=alt.Scale(type='log',domain=[startCase,150000]),\n",
    "         ),\n",
    "    color = 'doubling period',\n",
    "    tooltip = ['doubling period'])        \n",
    "\n",
    "# create a layer for the state labels\n",
    "# 1) make dataframe with each state's max days\n",
    "# 2) make a chart layer with text of state name to right of each state's rightmost point\n",
    "stateLabels_df = deceased10_df[deceased10_df['sinceDay0'] == deceased10_df.groupby(['state'])['sinceDay0'].transform(max)]\n",
    "labelChart = alt.Chart(stateLabels_df).mark_text(align='left', baseline='middle', dx=10).encode(\n",
    "    x='sinceDay0',\n",
    "    y='deceased',\n",
    "    text='state',\n",
    "    color='state')\n",
    "\n",
    "#now put the text labels layer on top of state labels Chart\n",
    "labelChart = labelChart + alt.Chart(textLabels_df).mark_text(align='right', baseline='bottom', dx=0, size=18,opacity=0.5).encode(\n",
    "    x='labelX',\n",
    "    y='labelY',\n",
    "    text='labelText')\n",
    "\n",
    "\n",
    "## Create some tooltip behavior - show Y values on mouseover\n",
    "# Step 1: Selection that chooses nearest point based on value on x-axis\n",
    "nearest = alt.selection(type='single', nearest=True, on='mouseover',\n",
    "                            fields=['sinceDay0'])\n",
    "\n",
    "# Step 2: Transparent selectors across the chart. This is what tells us\n",
    "# the x-value of the cursor\n",
    "selectors = alt.Chart().mark_point().encode(\n",
    "    x=\"sinceDay0:Q\",\n",
    "    opacity=alt.value(0),\n",
    ").add_selection(\n",
    "    nearest\n",
    ")\n",
    "\n",
    "# Step 3: Add text, show values in column when it's the nearest point to \n",
    "# mouseover, else show blank\n",
    "text = lineChart.mark_text(align='center', dx=3, dy=-20).encode(\n",
    "    text=alt.condition(nearest, 'deceased', alt.value(' '))\n",
    ")\n",
    "\n",
    "\n",
    "#Finally, lets show the chart!\n",
    "\n",
    "chart = alt.layer(lineChart, selectors, text, data=deceased10_df)\n",
    "\n",
    "display(chart)\n",
    "display(html_credits)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# make dataframe for text labels on chart - hand edit these label locations\n",
    "textLabels_df = pd.DataFrame(\n",
    "    [[9,30000,'doubles every day'],\n",
    "     [28,31000,'doubles every 3 days'],\n",
    "     [32,1000, 'doubles every week']],\n",
    "    columns =['labelX', 'labelY','labelText']\n",
    ")\n",
    "\n",
    "startCase = 100000\n",
    "\n",
    "# make dataframe with only points >=100 positives\n",
    "positive100_df = data_df.loc[data_df['positive']>=startCase]\n",
    "\n",
    "## add US to that dataframe\n",
    "nationpos100_df = nation_df.loc[nation_df['positive']>=startCase]\n",
    "positive100_df= pd.concat ([positive100_df,nationpos100_df])\n",
    "\n",
    "# group positive100 dataframe by state and then increasing order of date\n",
    "positive100_df = positive100_df.sort_values(by=['state','date'])\n",
    "positive100_df = positive100_df.reset_index()\n",
    "\n",
    "# make a list of the states with 10 or more deaths (don't really need this)\n",
    "# state_list = list(set(positive100_df['state']))\n",
    "\n",
    "# add a column for the number of days since the 100th case for each state\n",
    "for state, df in positive100_df.groupby('state'):\n",
    "    positive100_df.loc[df.index,'sinceDay0'] = range(0, len(df))\n",
    "positive100_df = positive100_df.astype({'sinceDay0': 'int32'})\n",
    "\n",
    "    \n",
    "# Now create plotlines for each state since 10 deaths\n",
    "lineChart = alt.Chart(positive100_df, title=f\"US States: total cases since {startCase}th case\").mark_line(interpolate='basis').encode(\n",
    "    alt.X('sinceDay0:Q', axis=alt.Axis(title=f'Days since {startCase}th case')),\n",
    "    alt.Y('positive:Q',\n",
    "          axis = alt.Axis(title='Cumulative positive cases'),\n",
    "          scale=alt.Scale(type='log')),\n",
    "    tooltip=['state', 'sinceDay0', 'deceased', 'positive'],\n",
    "    color = 'state'\n",
    ").properties(width=800,height=400)\n",
    "\n",
    "## Create a layer with the lines for doubling every day and doubling every week\n",
    "# make dataframe with lines to indicate doubling every day, 3 days, week \n",
    "\n",
    "days = {'day':[1,2,3,4,5,10,15,20, max(positive100_df.sinceDay0)+5]}\n",
    "\n",
    "logRuleDay_df = pd.DataFrame (days, columns=['day'])\n",
    "logRuleDay_df['case']= startCase * pow(2,logRuleDay_df['day'])\n",
    "logRuleDay_df['doubling period']='every day'\n",
    "\n",
    "logRule3Days_df = pd.DataFrame (days, columns=['day'])\n",
    "logRule3Days_df['case']= startCase * pow(2,(logRule3Days_df['day'])/3)\n",
    "logRule3Days_df['doubling period']='three days'\n",
    "\n",
    "logRuleWeek_df = pd.DataFrame (days, columns=['day'])\n",
    "logRuleWeek_df['case']= startCase * pow(2,(logRuleWeek_df['day'])/7)\n",
    "logRuleWeek_df['doubling period']='every week'\n",
    "\n",
    "logRules_df = pd.concat([logRuleDay_df, logRule3Days_df, logRuleWeek_df])\n",
    "logRules_df = logRules_df.reset_index()\n",
    "\n",
    "ruleChart = alt.Chart(logRules_df).mark_line(opacity=0.2,clip=True).encode(\n",
    "    alt.X('day:Q',\n",
    "            scale=alt.Scale(domain=[1, max(positive100_df.sinceDay0)+5])),\n",
    "    alt.Y('case', scale=alt.Scale(domain=[startCase,2000000], type='log'),\n",
    "         ),\n",
    "    color = 'doubling period')\n",
    "\n",
    "# create a layer for the state labels\n",
    "# 1) make dataframe with each state's max days\n",
    "# 2) make a chart layer with text of state name to right of each state's rightmost point\n",
    "stateLabels_df = positive100_df[positive100_df['sinceDay0'] == positive100_df.groupby(['state'])['sinceDay0'].transform(max)]\n",
    "labelChart = alt.Chart(stateLabels_df).mark_text(align='left', baseline='middle', dx=10).encode(\n",
    "    x='sinceDay0',\n",
    "    y='positive',\n",
    "    text='state',\n",
    "    color='state')\n",
    "\n",
    "#now put the text labels layer on top of state labels Chart\n",
    "labelChart = labelChart + alt.Chart(textLabels_df).mark_text(align='right', baseline='bottom', dx=0, size=18,opacity=0.5).encode(\n",
    "    x='labelX',\n",
    "    y='labelY',\n",
    "    text='labelText')\n",
    "\n",
    "#Create some tooltip behavior\n",
    "# Step 1: Selection that chooses nearest point based on value on x-axis\n",
    "nearest = alt.selection(type='single', nearest=True, on='mouseover',\n",
    "                            fields=['sinceDay0'])\n",
    "\n",
    "# Step 2: Transparent selectors across the chart. This is what tells us\n",
    "# the x-value of the cursor\n",
    "selectors = alt.Chart().mark_point().encode(\n",
    "    x=\"sinceDay0:Q\",\n",
    "    opacity=alt.value(0),\n",
    ").add_selection(\n",
    "    nearest\n",
    ")\n",
    "\n",
    "# Step 3: Add text, show values in Sex column when it's the nearest point to \n",
    "# mouseover, else show blank\n",
    "text = lineChart.mark_text(align='center', dx=3, dy=-20).encode(\n",
    "    text=alt.condition(nearest, 'positive', alt.value(' '))\n",
    ")\n",
    "\n",
    "\n",
    "#Finally, lets show the chart!\n",
    "\n",
    "chart = alt.layer(lineChart, selectors, text, data=positive100_df)\n",
    "#chart = alt.layer(lineChart, ruleChart, labelChart)\n",
    "chart.properties (width=400,height=800)\n",
    "display(chart)\n",
    "display(html_credits)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Daily Cumulative Totals\n",
    "\n",
    "Cumulative reported totals of positive cases and deaths. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "base = alt.Chart(\n",
    "    daily_totals\n",
    ").mark_bar(size=2).encode(\n",
    "    alt.X('date', axis=alt.Axis(title='')\n",
    "    )\n",
    ").properties(\n",
    "    height=200,\n",
    "    width=400\n",
    ")\n",
    "\n",
    "cumulative = base.encode(alt.Y('positive', title = 'Cumulative cases'))\n",
    "cumulative_deaths = base.encode(alt.Y('deceased', title = 'Cumulative deaths'))\n",
    "rates = base.encode(alt.Y('positive_diff', title='Daily cases'))\n",
    "rates_deaths = base.encode(alt.Y('deceased_diff', title='Daily deaths'))\n",
    "chart = alt.vconcat(\n",
    "    cumulative | rates, cumulative_deaths | rates_deaths,\n",
    "    title='Cumulative Covid-19 cases and deaths in the U.S.'\n",
    ").configure_title(\n",
    "    anchor='middle'\n",
    ")\n",
    "display(chart)\n",
    "display(html_credits)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Total tests and positives per 100k population"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "most_recent_test_date = data_df['date'].max()\n",
    "most_recent_df = data_df[data_df['date'] == most_recent_test_date]\n",
    "print(\"Most recent test date\", most_recent_test_date)\n",
    "print(len(most_recent_df), \"states/territories have data on this date.\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "viz_df = most_recent_df.sort_values('tested_100k', ascending=False)\n",
    "chart = alt.Chart(viz_df, title=\"Cases (orange points) and tests(blue bars) per 100k\").encode(alt.X('state', sort=None))\n",
    "tests = chart.mark_bar().encode(alt.Y('tested_100k', axis=alt.Axis(title='COVID-19 Tests/100k, Positive Cases/100k')))\n",
    "positives = chart.mark_point(color='orange', filled=True, size=100, opacity=1).encode(alt.Y('positive_100k'))\n",
    "display(alt.layer(tests, positives))\n",
    "display(html_credits)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Counts and rates by state\n",
    "\n",
    "Taking a look at the three states with the highest per-capita incidence of covid-19. The red and yellow curves represent the total tests and total positive tests respectively. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# produce the charts for a few states\n",
    "\n",
    "charts=[]\n",
    "for state in most_recent_df.sort_values('tested_100k', ascending=False)['state'].to_list()[:3]: \n",
    "    state_df = tdf_diff[tdf_diff['state'] == state].copy()\n",
    "\n",
    "    base = alt.Chart(state_df, title=state).encode(alt.X('date', axis=alt.Axis(title='Date'))).properties(width=250, height=150)\n",
    "    dailies = base.mark_bar(size=6).encode(alt.Y('positive_diff', axis=alt.Axis(title='Daily positive')))\n",
    "\n",
    "    totals = base.mark_line(color='red').encode(alt.Y('total_10', axis=alt.Axis(title='Total/10'))) \n",
    "    positives = totals.mark_line(color='orange').encode(alt.Y('positive', axis=alt.Axis(title='Positive')))\n",
    "    cumulative = totals + positives\n",
    "\n",
    "    ratio = base.mark_line(color='red').encode(alt.Y('ratio', axis=alt.Axis(title='Positive/Total'), scale=alt.Scale(domain=(0,1))))\n",
    "    \n",
    "    charts.append(alt.layer(dailies, cumulative).resolve_scale(y='independent'))\n",
    "\n",
    "display(alt.hconcat(*charts))\n",
    "display(html_credits)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "hide_input": true,
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}