Skip to content
Snippets Groups Projects
Commit abbcd8c8 authored by Rok Roškar's avatar Rok Roškar Committed by Rok Roškar
Browse files

chore: refactor dataframe and plotting

parent 2a7c31c4
No related branches found
No related tags found
No related merge requests found
This diff is collapsed.
This diff is collapsed.
...@@ -10,7 +10,7 @@ import os ...@@ -10,7 +10,7 @@ import os
def read_jhu_covid_df(ts_folder, name): def read_jhu_covid_df(ts_folder, name):
filename = os.path.join(ts_folder, f"time_series_covid19_{name}_global.csv") filename = os.path.join(ts_folder, f"time_series_covid19_{name}_global.csv")
df = pd.read_csv(filename) df = pd.read_csv(filename)
df = df.set_index(['Province/State', 'Country/Region', 'Lat', 'Long']) df = df.set_index(["Province/State", "Country/Region", "Lat", "Long"])
df.columns = pd.to_datetime(df.columns) df.columns = pd.to_datetime(df.columns)
return df return df
...@@ -27,7 +27,7 @@ def read_jhu_frames_map(ts_folder): ...@@ -27,7 +27,7 @@ def read_jhu_frames_map(ts_folder):
def read_rates_covid_df(rates_folder, name): def read_rates_covid_df(rates_folder, name):
filename = os.path.join(rates_folder, f"ts_rates_19-covid-{name}.csv") filename = os.path.join(rates_folder, f"ts_rates_19-covid-{name}.csv")
df = pd.read_csv(filename).drop("Unnamed: 0", axis=1) df = pd.read_csv(filename).drop("Unnamed: 0", axis=1)
df = df.set_index(['Country/Region']) df = df.set_index(["Country/Region"])
df.columns = pd.to_datetime(df.columns) df.columns = pd.to_datetime(df.columns)
return df return df
...@@ -42,41 +42,56 @@ def read_rates_frames_map(rates_folder): ...@@ -42,41 +42,56 @@ def read_rates_frames_map(rates_folder):
def read_geodata(geodata_path): def read_geodata(geodata_path):
geodata_df = pd.read_csv(geodata_path) geodata_df = pd.read_csv(geodata_path)
geodata_df = geodata_df.drop('Unnamed: 0', axis=1) geodata_df = geodata_df.drop("Unnamed: 0", axis=1)
geodata_df = geodata_df.rename({'name_jhu':'Country/Region'}, axis=1) geodata_df = geodata_df.rename({"name_jhu": "Country/Region"}, axis=1)
geodata_df = geodata_df.set_index('Country/Region') geodata_df = geodata_df.set_index("Country/Region")
return geodata_df return geodata_df
def latest_jhu_country_ser(jhu_frames_map, name): def latest_jhu_country_ser(jhu_frames_map, name):
return jhu_frames_map[name].iloc[:,-1].groupby(level='Country/Region').sum() return jhu_frames_map[name].iloc[:, -1].groupby(level="Country/Region").sum()
def countries_with_number_of_cases(jhu_frames_map, name, count): def countries_with_number_of_cases(jhu_frames_map, name, count):
case_count_ser = latest_jhu_country_ser(jhu_frames_map, 'confirmed') case_count_ser = latest_jhu_country_ser(jhu_frames_map, "confirmed")
countries_over_thresh = case_count_ser[case_count_ser > count - 1].index countries_over_thresh = case_count_ser[case_count_ser > count - 1].index
return countries_over_thresh return countries_over_thresh
def latest_rates_ser(rates_frames_map, name): def latest_rates_ser(rates_frames_map, name):
return rates_frames_map[name].iloc[:,-1] return rates_frames_map[name].iloc[:, -1]
def compute_map_df(rates_frames_map, jhu_frames_map, geodata_df, countries_over_thresh): def compute_map_df(rates_frames_map, jhu_frames_map, geodata_df, countries_over_thresh):
map_df = pd.concat([ map_df = pd.concat(
latest_rates_ser(rates_frames_map, 'confirmed'), [
latest_rates_ser(rates_frames_map, 'deaths')], axis=1) latest_rates_ser(rates_frames_map, "confirmed"),
nominal_df = pd.concat([ latest_rates_ser(rates_frames_map, "deaths"),
latest_jhu_country_ser(jhu_frames_map, 'confirmed'), ],
latest_jhu_country_ser(jhu_frames_map, 'deaths')], axis=1) axis=1,
map_df = pd.concat([map_df, nominal_df, geodata_df[['Longitude', 'Latitude']]], axis=1) )
nominal_df = pd.concat(
[
latest_jhu_country_ser(jhu_frames_map, "confirmed"),
latest_jhu_country_ser(jhu_frames_map, "deaths"),
],
axis=1,
)
map_df = pd.concat(
[map_df, nominal_df, geodata_df[["Longitude", "Latitude"]]], axis=1
)
# Restrict to countries with 100 or more cases # Restrict to countries with 100 or more cases
map_df = map_df.loc[countries_over_thresh].dropna() map_df = map_df.loc[countries_over_thresh].dropna()
map_df = map_df.reset_index() map_df = map_df.reset_index()
map_df.columns = ['Country/Region', map_df.columns = [
'Confirmed/100k', 'Deaths/100k', "Country/Region",
'Confirmed', 'Deaths', "Confirmed/100k",
'Long', 'Lat'] "Deaths/100k",
"Confirmed",
"Deaths",
"Long",
"Lat",
]
return map_df return map_df
...@@ -86,41 +101,114 @@ def map_of_variable(map_df, variable, title): ...@@ -86,41 +101,114 @@ def map_of_variable(map_df, variable, title):
graticule = alt.graticule() graticule = alt.graticule()
# Source of land data # Source of land data
source = alt.topo_feature(data.world_110m.url, 'countries') source = alt.topo_feature(data.world_110m.url, "countries")
# Layering and configuring the components # Layering and configuring the components
p = alt.layer( p = (
alt.Chart(sphere).mark_geoshape(fill='#cae6ef'), alt.layer(
alt.Chart(graticule).mark_geoshape(stroke='white', strokeWidth=0.5), alt.Chart(sphere).mark_geoshape(fill="#cae6ef"),
alt.Chart(source).mark_geoshape(fill='#dddddd', stroke='#aaaaaa'), alt.Chart(graticule).mark_geoshape(stroke="white", strokeWidth=0.5),
alt.Chart(map_df).mark_circle(opacity=0.6).encode( alt.Chart(source).mark_geoshape(fill="#dddddd", stroke="#aaaaaa"),
longitude='Long:Q', alt.Chart(map_df)
latitude='Lat:Q', .mark_circle(opacity=0.6)
size=alt.Size(f'{variable}:Q', title="Cases"), .encode(
color=alt.value('steelblue'), longitude="Long:Q",
tooltip=["Country/Region:N", latitude="Lat:Q",
"Confirmed:Q", "Deaths:Q", size=alt.Size(f"{variable}:Q", title="Cases"),
"Confirmed/100k:Q", "Deaths/100k:Q"] color=alt.value("steelblue"),
tooltip=[
"Country/Region:N",
"Confirmed:Q",
"Deaths:Q",
"Confirmed/100k:Q",
"Deaths/100k:Q",
],
),
) )
).project( .project("naturalEarth1")
'naturalEarth1' .properties(width=600, height=400, title=f"{title} cases per 100k inhabitants")
).properties(width=600, height=400, title=f"{title} cases per 100k inhabitants" .configure_view(stroke=None)
).configure_view(stroke=None) )
return p return p
def growth_df(rates_frames_map, geodata_df, name, countries_over_thresh, cutoff): def growth_df(rates_frames_map, geodata_df, name, countries_over_thresh, cutoff):
latest_confirmed_ser = rates_frames_map['confirmed'].iloc[:,-1] latest_confirmed_ser = rates_frames_map["confirmed"].iloc[:, -1]
countries_over_1 = latest_confirmed_ser[latest_confirmed_ser >= cutoff].reset_index()['Country/Region'] countries_over_1 = latest_confirmed_ser[
latest_confirmed_ser >= cutoff
].reset_index()["Country/Region"]
confirmed_rate_df = rates_frames_map['confirmed'] confirmed_rate_df = rates_frames_map["confirmed"]
confirmed_rate_df = confirmed_rate_df.loc[ confirmed_rate_df = confirmed_rate_df.loc[
confirmed_rate_df.index.isin(countries_over_1) & confirmed_rate_df.index.isin(countries_over_1)
confirmed_rate_df.index.isin(countries_over_thresh)] & confirmed_rate_df.index.isin(countries_over_thresh)
]
confirmed_rate_df = confirmed_rate_df.join( confirmed_rate_df = confirmed_rate_df.join(
geodata_df[['Longitude', 'Latitude', 'region_un']]).set_index( geodata_df[["Longitude", "Latitude", "region_un"]]
['Longitude', 'Latitude', 'region_un'], append=True) ).set_index(["Longitude", "Latitude", "region_un"], append=True)
confirmed_rate_df = confirmed_rate_df.stack().reset_index() confirmed_rate_df = confirmed_rate_df.stack().reset_index()
confirmed_rate_df = confirmed_rate_df.rename( confirmed_rate_df = confirmed_rate_df.rename(
{'region_un': 'Geo Region', 'level_4': 'Date', 0: 'Confirmed/100k'}, axis=1) {"region_un": "Geo Region", "level_4": "Date", 0: "Confirmed/100k"}, axis=1
)
return confirmed_rate_df return confirmed_rate_df
def get_region_populations(country_iso3):
    """Fetch populations of a country's first-level regions from Wikidata.

    country_iso3: ISO 3166-1 alpha-3 code of the country (e.g. "CHE")

    Returns a dict mapping each region's ISO 3166-2 code to its population
    as an int. Performs a live SPARQL query against the Wikidata endpoint,
    so network access is required.
    """
    import sys

    from SPARQLWrapper import JSON, SPARQLWrapper

    endpoint_url = "https://query.wikidata.org/sparql"
    # Doubled braces survive str.format(); only {country_iso3} is substituted.
    query = """
SELECT DISTINCT ?population ?region_iso ?regionLabel
{{
  # select country by its iso-3
  ?country wdt:P298 "{country_iso3}" .
  # region has an iso ?iso
  ?region wdt:P300 ?region_iso .
  # region has a population of ?population
  ?region wdt:P1082 ?population .
  # country contains region ?region
  ?country wdt:P150 ?region .
  # country is an instance of sovereign state
  ?country wdt:P31 wd:Q3624078 .
  SERVICE wikibase:label {{ bd:serviceParam wikibase:language "en" }}
}}
"""

    def run_query(url, sparql_text):
        # Identify ourselves per Wikidata etiquette.
        # TODO adjust user agent; see https://w.wiki/CX6
        agent = "WDQS-example Python/%s.%s" % (
            sys.version_info[0],
            sys.version_info[1],
        )
        client = SPARQLWrapper(url, agent=agent)
        client.setQuery(sparql_text)
        client.setReturnFormat(JSON)
        return client.query().convert()

    response = run_query(endpoint_url, query.format(country_iso3=country_iso3))
    bindings = response["results"]["bindings"]
    return {
        row["region_iso"]["value"]: int(row["population"]["value"])
        for row in bindings
    }
def make_since_df(df, column="positive", region_column="state", start_case=100):
"""Make dataframe shifted by date of case 100."""
# make dataframe with only points >=100 positives
since_df = df.loc[df[column] >= start_case]
# group since_df dataframe by region and then increasing order of date
since_df = since_df.sort_values(by=[region_column, "date"])
since_df = since_df.reset_index()
# add a column for the number of days since the 100th case for each state
for _, df in since_df.groupby(region_column):
since_df.loc[df.index, "sinceDay0"] = range(0, len(df))
since_df = since_df.astype({"sinceDay0": "int32"})
return since_df
\ No newline at end of file
"""Simple functions to help with common charts."""
import altair as alt
import pandas as pd
def make_rule_chart(
    start_case=100,
    max_case=10000,
    max_days=100,
    pos_day=(9, 30000),
    pos_3days=(28, 31000),
    pos_week=(20, 300),
):
    """
    Build a chart with guide lines for cases doubling every day, 3 days, week.

    start_case: case starting point
    max_case: maximum number of cases to plot
    max_days: maximum number of days to show
    pos_day: position of the "doubles every day" label
    pos_3days: position of "doubles every three days" label
    pos_week: position of "doubles every week" label

    Returns an altair layered chart (guide lines + text labels) intended to be
    overlaid on a cumulative-case chart.
    """
    days = {"day": range(max_days + 1)}
    # One frame per doubling period: case = start_case * 2**(day / period).
    # Use a float base: the original computed pow(2, int_series) for the
    # every-day line, which silently overflows int64 once day > 62 (the
    # default max_days is 100), corrupting that guide line.
    frames = []
    for label, period in (("every day", 1), ("three days", 3), ("every week", 7)):
        rule_df = pd.DataFrame(days, columns=["day"])
        rule_df["case"] = start_case * pow(2.0, rule_df["day"] / period)
        rule_df["doubling period"] = label
        frames.append(rule_df)
    logRules_df = pd.concat(frames)
    logRules_df = logRules_df.reset_index()

    ruleChart = (
        alt.Chart(logRules_df)
        .mark_line(opacity=0.2, clip=True)
        .encode(
            alt.X("day:Q", scale=alt.Scale(domain=[1, max_days])),
            alt.Y("case", scale=alt.Scale(domain=[start_case, max_case])),
            color=alt.Color("doubling period"),
        )
    )
    # make dataframe for text labels on chart - hand edit these label locations
    textLabels_df = pd.DataFrame(
        [
            [*pos_day, "doubles every day"],
            [*pos_3days, "doubles every 3 days"],
            [*pos_week, "doubles every week"],
        ],
        columns=["labelX", "labelY", "labelText"],
    )
    labelChart = (
        alt.Chart(textLabels_df)
        .mark_text(align="right", baseline="bottom", dx=0, size=15, opacity=0.5)
        .encode(x="labelX", y="labelY", text="labelText")
    )
    return ruleChart + labelChart
def generate_region_chart(
    base,
    column,
    region_column,
    ytitle,
    tooltip_title,
    legend_title=None,
    scale_type="linear",
):
    """
    Produce a regional line chart of `column` over time.

    base: chart base - create with alt.Chart()
    column: name of column to use for y-axis
    region_column: column to use for regions
    ytitle: name of y-axis
    tooltip_title: tooltip title for y-axis values
    legend_title: legend title (defaults to region_column)
    scale_type: "linear" or "log"
    """
    if legend_title is None:
        legend_title = region_column
    # Clicking a legend entry highlights that region's line and dims the rest.
    region_select = alt.selection_multi(fields=[region_column], bind="legend")
    hover_info = [
        alt.Tooltip(region_column, title=legend_title),
        alt.Tooltip(column, title=tooltip_title),
        alt.Tooltip("date", title="Date"),
    ]
    chart = (
        base.mark_line()
        .encode(
            alt.X("date", title="Date"),
            alt.Y(column, title=ytitle, scale=alt.Scale(type=scale_type)),
            color=alt.Color(region_column, legend=alt.Legend(title=legend_title)),
            opacity=alt.condition(region_select, alt.value(1), alt.value(0.2)),
            tooltip=hover_info,
        )
        .add_selection(region_select)
    )
    return chart
def make_region_since_chart(
    base,
    column,
    time_column,
    region_column,
    xtitle,
    ytitle,
    tooltip_title,
    legend_title,
    scale_type="log",
):
    """
    Make chart "days since X case".

    base: chart base - create with alt.Chart()
    column: name of column to use for y-axis
    time_column: column that encodes the time since X case
    region_column: column to use for regions
    xtitle: as it says on the tin
    ytitle: as it says on the tin
    tooltip_title: as it says on the tin
    legend_title: as it says on the tin
    scale_type: "linear" or "log" — was documented but hard-coded to "log";
        now a real keyword defaulting to the previous behavior
    """
    # Legend-bound selection: clicking a legend entry highlights that region.
    selection = alt.selection_multi(fields=[region_column], bind="legend")
    opacity = alt.condition(selection, alt.value(1), alt.value(0.2))
    lineChart = (
        base.mark_line()
        .encode(
            alt.X(time_column, axis=alt.Axis(title=xtitle)),
            alt.Y(
                column,
                axis=alt.Axis(title=ytitle),
                scale=alt.Scale(type=scale_type),
            ),
            tooltip=[
                alt.Tooltip(region_column, title=legend_title),
                alt.Tooltip(column, title=tooltip_title),
                alt.Tooltip("date", title="Date"),
            ],
            color=alt.Color(region_column, legend=alt.Legend(title=legend_title)),
            opacity=opacity,
        )
        .add_selection(selection)
    )
    return lineChart
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment