Data organizers: proof of concept#

This notebook gives an overview of ideas for refactoring ArviZ plotting. For now it only contains very rough proof-of-concept implementations. One of the main goals is modularization, so a great deal of these ideas could also be used with raw xarray.

It assumes you already have some familiarity with ArviZ plotting.

import arviz as az
import numpy as np
import xarray as xr
import matplotlib.pyplot as plt
xr.set_options(display_expand_data=False);
az.style.use("arviz-white")
idata = az.load_arviz_data("rugby")
post = idata.posterior

Draft proposed extensions#

I propose to create a class similar to facetgrid in xarray or in old seaborn.

Important

For now this works on DataArrays only! At this point, this is a proof of concept we can use to discuss the API and make sure I am moving in the right direction. Once we are satisfied with it, I’ll work on making it work with Datasets and iterables of Datasets (for plotting purposes, Dataset and InferenceData are equivalent: we will never want to loop over groups to plot, only to use information from multiple groups).

from xrtist import PlotCollection

A PlotCollection should be initialized or reused for each plot. It can be initialized with the .grid or .wrap classmethods. At init time it mainly creates two attributes: .ds and .viz.

  • .ds is a dataset that stores the aesthetics mapping. It is used so that when we loop, we subset that dataset with the same sel used to subset the data and get the right aesthetics for that plot.

  • .viz is used to store all the elements of the generated graph. At init time it only has the figure and axes. Then, as we plot, the artists are stored there too.

The main method of PlotCollection is .map. It loops over each of the data subsets defined in the PlotCollection and calls a plotting function. These plotting functions should take a DataArray with the data to plot and the axes to plot on, and return a single artist/glyph. It is possible to skip returning the artist, but this should be done in special cases only.
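The loop that .map performs can be pictured with plain xarray. The sketch below is not the xrtist implementation: `map_sketch`, `record_artist` and the toy data are hypothetical names made up for illustration. It only shows the idea of reusing one selection to subset both the data and the aesthetics dataset.

```python
import itertools

import numpy as np
import xarray as xr


def map_sketch(da, aes_ds, loop_dims, fun):
    """Call `fun` once per combination of coordinate values along `loop_dims`.

    Hypothetical stand-in for the looping idea, not xrtist code.
    """
    results = {}
    coord_values = [da[dim].values.tolist() for dim in loop_dims]
    for combo in itertools.product(*coord_values):
        sel = dict(zip(loop_dims, combo))
        # The same `sel` subsets the data and looks up the aesthetics.
        aes_kwargs = {
            name: aes_ds[name].sel({d: sel[d] for d in aes_ds[name].dims}).item()
            for name in aes_ds.data_vars
        }
        results[combo] = fun(da.sel(sel), **aes_kwargs)
    return results


def record_artist(values, **kwargs):
    # A real artist function would draw on an axes and return the artist;
    # here we only record what it would have received.
    return (values.size, kwargs)


rng = np.random.default_rng(0)
da = xr.DataArray(
    rng.normal(size=(2, 10, 3)),
    dims=("chain", "draw", "team"),
    coords={"chain": [0, 1], "team": ["Wales", "Italy", "France"]},
)
aes_ds = xr.Dataset({"color": ("chain", ["C0", "C1"])}, coords={"chain": [0, 1]})
out = map_sketch(da, aes_ds, ["team", "chain"], record_artist)
```

Each entry of `out` corresponds to one call of the artist function, keyed by the (team, chain) pair it was called for.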

Let’s see it in practice!

We start by defining a function roughly equivalent to plot_kde but that follows the rules explained above.

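def kde_artist(values, target, **kwargs):
    # .map passes the backend name along; drop it before calling matplotlib
    kwargs.pop("backend", None)
    # flatten all chains and draws into a single 1D sample for the KDE
    grid, pdf = az.kde(np.array(values).flatten())
    # Axes.plot returns a list of lines; return the single Line2D artist
    return target.plot(grid, pdf, **kwargs)[0]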

Base example: plot kdes for each of the 6 teams#

We start with a basic example, plotting a KDE for each of the teams. In this case, we only want to generate a subplot for each team, with no aesthetics or anything else. We indicate this with cols=["team"]. As we are using .wrap, instead of a single row of 6 plots we get the plots wrapped over 2 rows. A .grid method is already available, and something for plot_pair-like plots will be added, be it a .pair or .matrix classmethod or a PlotMatrix class with a similar API.

pc = PlotCollection.wrap(
    post["atts"], 
    cols=["team"], 
    subplot_kws={"figsize": (12, 8)}
)
pc.map(kde_artist, "kde")
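The wrapping itself is simple index arithmetic. The sketch below assumes subplots are filled left to right, top to bottom; `wrap_positions` is a hypothetical helper, not part of xrtist. With 6 plots and 4 columns it yields the same row and col values that are stored in .viz for these plots later in this notebook.

```python
def wrap_positions(n_plots, col_wrap):
    """Return (rows, cols) index lists for `n_plots` filled row by row."""
    rows, cols = zip(*(divmod(i, col_wrap) for i in range(n_plots)))
    return list(rows), list(cols)


rows4, cols4 = wrap_positions(6, 4)  # one row of 4 plots, then a row of 2
rows3, cols3 = wrap_positions(6, 3)  # two full rows of 3 plots
print(rows4, cols4)
print(rows3, cols3)
```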

Add aesthetics: plot kdes for each team in different subplots, overlaying the different chains with different colors.#

We now use the aes argument to have each chain have a different color.

Note: not all plots will support all aesthetics.

pc = PlotCollection.wrap(
    post["atts"], 
    cols=["team"], 
    aes={"color": ["chain"]}, 
    color=[f"C{i}" for i in range(4)],
    subplot_kws={"figsize": (12, 8)}
)
pc.viz
pc.map(kde_artist, "kde")

The figure (chart), axes (plot), row and col indexes (row, col) and generated artists are stored in an xarray dataset:

pc.viz
<xarray.Dataset>
Dimensions:  (team: 6, chain: 4)
Coordinates:
  * team     (team) object 'Wales' 'France' 'Ireland' ... 'Italy' 'England'
  * chain    (chain) int64 0 1 2 3
Data variables:
    chart    object Figure(432x288)
    plot     (team) object AxesSubplot(0.0558455,0.562363;0.184204x0.409583) ...
    row      (team) int64 0 0 0 0 1 1
    col      (team) int64 0 1 2 3 0 1
    kde      (chain, team) object Line2D(_child0) ... Line2D(_child3)

The same goes for the aesthetics, which are stored in .ds:

pc.ds
<xarray.Dataset>
Dimensions:  (chain: 4)
Coordinates:
  * chain    (chain) int64 0 1 2 3
Data variables:
    color    (chain) <U2 'C0' 'C1' 'C2' 'C3'

Multiple aesthetics for the same dimension#

We can map multiple aesthetics to the same dimension, and they will all be updated in sync in all plots:

pc = PlotCollection.wrap(
    post["atts"], 
    cols=["team"], 
    aes={"color": ["chain"], "ls": ["chain"], "lw": ["team"]}, 
    color=[f"C{i}" for i in range(4)],
    ls=["-", ":"],
    lw=np.linspace(1, 3, 6),
    subplot_kws={"figsize": (12, 8)}
)
pc.map(kde_artist)
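In the call above, ls receives 2 values while chain has 4 entries. How the remaining entries are filled is not spelled out here; one plausible reading is that the values are recycled. The sketch below builds an aesthetics dataset under that assumption. `build_aes` is a hypothetical helper, not the xrtist API.

```python
import numpy as np
import xarray as xr


def build_aes(coords, aes, **values):
    """Build a dataset with one variable per aesthetic (illustrative only)."""
    ds = xr.Dataset(coords=coords)
    for name, dims in aes.items():
        shape = tuple(len(coords[dim]) for dim in dims)
        # np.resize repeats the given values until the target shape is filled.
        ds[name] = (dims, np.resize(np.asarray(values[name]), shape))
    return ds


coords = {
    "chain": [0, 1, 2, 3],
    "team": ["Wales", "France", "Ireland", "Scotland", "Italy", "England"],
}
aes_ds = build_aes(
    coords,
    {"color": ["chain"], "ls": ["chain"], "lw": ["team"]},
    color=[f"C{i}" for i in range(4)],
    ls=["-", ":"],
    lw=np.linspace(1, 3, 6),
)
```

Because color and ls share the chain dimension, selecting one chain returns a matching pair of values, which is what keeps the two aesthetics in sync.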

Easy access to all plot elements#

As we have stored all the plotting-related data, including the artists, in the .viz attribute, we can select artists from the plot and modify them with label-based indexing. Here, after plotting, we make the line for the Italy team and 3rd chain use diamonds as markers, with size 10, showing a marker only once every 50 datapoints.

Note: I’d also like to add an extra coordinate, grid or something of the sort, that can be used to select axes by their position in the subplot grid.

pc = PlotCollection.wrap(
    post["atts"], 
    cols=["team"], 
    aes={"color": ["chain"], "ls": ["chain"], "lw": ["team"]}, 
    color=[f"C{i}" for i in range(4)],
    ls=["-", ":"],
    lw=np.linspace(1, 3, 6),
    subplot_kws={"figsize": (12, 8)}
)
pc.map(kde_artist, "kde")
pc.viz["kde"].sel(team="Italy", chain=2).item().set(marker="D", markersize=10, markevery=50);
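The same selection pattern works on any object-dtype DataArray, independently of PlotCollection. The following is a minimal sketch with plain matplotlib and xarray; the toy curves and the two-team layout are made up for illustration.

```python
import matplotlib

matplotlib.use("Agg")  # non-interactive backend so the sketch runs anywhere
import matplotlib.pyplot as plt
import numpy as np
import xarray as xr

teams = ["Wales", "Italy"]
chains = [0, 1, 2]
fig, axes = plt.subplots(1, len(teams))
x = np.linspace(0, 1, 200)

# Store every Line2D in an object array indexed by (team, chain).
lines = np.empty((len(teams), len(chains)), dtype=object)
for i, ax in enumerate(axes):
    for j, chain in enumerate(chains):
        lines[i, j] = ax.plot(x, np.sin((chain + 1) * x))[0]

artists = xr.DataArray(
    lines, dims=("team", "chain"), coords={"team": teams, "chain": chains}
)

# Label-based selection returns the stored artist, which we modify in place.
line = artists.sel(team="Italy", chain=2).item()
line.set(marker="D", markersize=10, markevery=50)
```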

We can change the layout of the plots without needing to change the code that selects and modifies the specific artist:

pc = PlotCollection.wrap(
    post["atts"], 
    cols=["team"], 
    col_wrap=3, # <-- change here!!
    aes={"color": ["chain"], "ls": ["chain"], "lw": ["team"]}, 
    color=[f"C{i}" for i in range(4)],
    ls=["-", ":"],
    lw=np.linspace(1, 3, 6),
    subplot_kws={"figsize": (12, 8)}
)
pc.map(kde_artist, "kde")
pc.viz["kde"].sel(team="Italy", chain=2).item().set(marker="D", markersize=10, markevery=50);

And unlike xarray facetting, all aesthetics and facetting variables take lists of dimensions, not single dimensions:

pc = PlotCollection.wrap(
    post["atts"], 
    cols=["team", "chain"],
    col_wrap=6,
    aes={"color": ["chain"]}, 
    color=[f"C{i}" for i in range(4)],
    subplot_kws={"figsize": (12, 8)}
)
pc.map(kde_artist)

The plot variable continues to have both team and chain dimensions, which do not match the grid ones, as the selection below shows. This is ugly, but it is done to showcase this feature. Here we don’t have too many dimensions, so it looks better to facet with rows and cols using one variable each, which we can do with the .grid classmethod:

pc.viz["plot"].where(pc.viz["col"] == 0, drop=True)
<xarray.DataArray 'plot' (team: 4, chain: 2)>
AxesSubplot(0.0558455,0.851452;0.0674711x0.11964) ... AxesSubplot(0.0558455,0...
Coordinates:
  * team     (team) object 'Wales' 'France' 'Scotland' 'Italy'
  * chain    (chain) int64 0 2

pc = PlotCollection.grid(
    post["atts"], 
    cols=["team"],
    rows=["chain"],
    aes={"color": ["chain"]}, 
    color=[f"C{i}" for i in range(4)],
    subplot_kws={"figsize": (12, 8)}
)
pc.map(kde_artist)
for ax in pc.viz["plot"].sel(chain=0):
    ax.item().set_title(ax["team"].item())


for ax in pc.viz["plot"].isel(team=0):
    ax.item().set_ylabel(f"""chain = {ax["chain"].item()}""")

Mimic existing ArviZ functions#

plot_posterior#

We will now mimic plot_posterior. We use functions from the visuals module, which generate artists like our kde_artist function did. We also add a couple of extras that are mostly aesthetic, so they are hidden to not interfere with reading.

def title_artist(values, target, var_name, sel, isel, labeller_fun, **kwargs):
    kwargs.pop("backend", None)  # not a matplotlib keyword
    label = labeller_fun(var_name, sel, isel)
    return target.set_title(label, **kwargs)

def remove_left_axis(values, target, **kwargs):
    kwargs.pop("backend", None)  # not a matplotlib keyword
    # hide the y axis and its spine; nothing is returned, so use store_artist=False
    target.get_yaxis().set_visible(False)
    target.spines['left'].set_visible(False)

We now mimic plot_posterior (DataArray input only). plot_posterior takes an InferenceData and automatically generates the axes and plots everything. With PlotCollection, this becomes creating a plot collection (unless one is provided) and then calling .map multiple times to add all the elements to the plot: KDE, HDI interval line, point estimate marker, point estimate text…

from xrtist import visuals

def plot_posterior(da, plot_collection=None, labeller=None):
    if plot_collection is None:
        plot_collection = PlotCollection.wrap(da, cols=[dim for dim in da.dims if dim not in {"chain", "draw"}])
    if labeller is None:
        labeller = az.labels.NoVarLabeller()
        
    labeller_fun = labeller.make_label_vert
    
    plot_collection.map(visuals.kde, "kde")
    plot_collection.map(visuals.interval, "interval", color="grey")    
    plot_collection.map(visuals.point, "point_estimate", color="C0", size=25, marker="o")
    plot_collection.map(visuals.point_label, "point_label", color="C0", va="bottom", ha="center")
    plot_collection.map(remove_left_axis, store_artist=False)
    plot_collection.map(title_artist, "title", subset_info=True, labeller_fun=labeller_fun)
    return plot_collection

Here we can see that usage is basically the same as with our current API, but if we also expose the artist functions, it will be much easier for users to create their own variations of plot_posterior or to extend it with their own artist functions.

import matplotlib as mpl
mpl.rcParams.update(mpl.rcParamsDefault)

az.style.use("arviz-doc")
pc = plot_posterior(post["atts"])

Finally, we show how all the elements added to the plot have had their corresponding artists added to the .viz dataset.

pc.viz
<xarray.Dataset>
Dimensions:         (team: 6)
Coordinates:
  * team            (team) object 'Wales' 'France' ... 'Italy' 'England'
Data variables:
    chart           object Figure(3450x1500)
    plot            (team) object AxesSubplot(0.00362348,0.565778;0.237246x0....
    row             (team) int64 0 0 0 0 1 1
    col             (team) int64 0 1 2 3 0 1
    kde             (team) object Line2D(_child0) ... Line2D(_child0)
    interval        (team) object Line2D(_child1) ... Line2D(_child1)
    point_estimate  (team) object <matplotlib.collections.PathCollection obje...
    point_label     (team) object Text(0.17110852459382564, 0.470849748773512...
    title           (team) object Text(0.5, 1.0, 'Wales') ... Text(0.5, 1.0, ...

Use preprocessed computations#

Using these independent artist functions is convenient from a plotting perspective, but computationally it is much more demanding. For example, we need to compute the point estimate in both the marker and the label artist. In fact, to make the point estimate label look nice, we place it at 0.05 times the height of the KDE (with the marker at 0), so we also need to compute the KDE twice (or 3 times if we want to use a similar trick for the HDI label).

The overall idea was (and is) to split data organization, processing and plotting into pieces that are as independent as possible (keep in mind this is only a data organization proof of concept). Here we precompute all the quantities we will need for plot_posterior, and .map itself takes care of splitting this preprocessed data into subsets aligned with the raw data. Thus, if we use preprocessed=True, in each iteration the artist function gets both the raw data and the preprocessed data, properly subsetted. That is, we now compute the HDI, the point estimate and the KDE only once.

Note: this avenue also has a lot of potential for further speed-ups. As computing the kde is no longer done within the plotting loop (as it currently happens), we can parallelize the computation using one core per team for example. Then keep the looping for the plot only.

import xrtist.processing

def plot_posterior_pre(da, plot_collection=None, labeller=None):
    if plot_collection is None:
        plot_collection = PlotCollection.wrap(da, cols=[dim for dim in da.dims if dim not in {"chain", "draw"}])
        
    # pre compute all plot elements: hdi, mean and kde and add them to plot_collection
    pre_ds = xr.Dataset()
    pre_ds["interval"] = az.hdi(da)[da.name]
    pre_ds["point_estimate"] = da.mean(("chain", "draw"))
    grid, kde = xrtist.processing.kde(da)
    pre_ds["grid"] = grid
    pre_ds["kde"] = kde
    plot_collection.preprocessed_data = pre_ds
    
    if labeller is None:
        labeller = az.labels.NoVarLabeller()
    labeller_fun = labeller.make_label_vert
    
    # add all drawings to the plot with no computation inside artist functions anymore, only plotting
    plot_collection.map(visuals.kde, "kde", preprocessed=True)
    plot_collection.map(visuals.interval, "interval", preprocessed=True, color="grey")    
    plot_collection.map(visuals.point, "point_estimate", preprocessed=True, color="C0", size=25, marker="o")
    plot_collection.map(
        visuals.point_label, "point_label", preprocessed=True, point_label="mean",
        color="C0", va="bottom", ha="center"
    )
    plot_collection.map(remove_left_axis, store_artist=False)
    plot_collection.map(title_artist, "title", subset_info=True, labeller_fun=labeller_fun)
    return plot_collection

pc = plot_posterior_pre(post["atts"])

Preprocessed data is made available to the plot collection by setting the .preprocessed_data attribute, which remains available in the returned plot collection. So potentially, all the preprocessed data could also be reused in other plots (provided facetting and aesthetics are aligned!).

pc.preprocessed_data
<xarray.Dataset>
Dimensions:         (team: 6, hdi: 2, x_kde: 512)
Coordinates:
  * team            (team) object 'Wales' 'France' ... 'Italy' 'England'
  * hdi             (hdi) <U6 'lower' 'higher'
Dimensions without coordinates: x_kde
Data variables:
    interval        (team, hdi) float64 0.09464 0.2514 -0.1641 ... 0.1793 0.3376
    point_estimate  (team) float64 0.1711 -0.08273 0.1085 -0.1165 -0.337 0.2566
    grid            (team, x_kde) float64 0.02241 0.02301 ... 0.437 0.4376
    kde             (team, x_kde) float64 0.06984 0.07005 ... 0.04191 0.0422
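The alignment between raw and preprocessed data can be sketched with plain xarray. This is only an illustration of the idea, not how .map is implemented: the data is random, and an equal-tailed quantile interval stands in for az.hdi to keep the sketch free of extra dependencies.

```python
import numpy as np
import xarray as xr

rng = np.random.default_rng(1)
teams = ["Wales", "France", "Italy"]
da = xr.DataArray(
    rng.normal(size=(4, 100, 3)),
    dims=("chain", "draw", "team"),
    coords={"team": teams},
)

# Compute every summary once, outside of the plotting loop.
pre_ds = xr.Dataset()
pre_ds["point_estimate"] = da.mean(("chain", "draw"))
# Equal-tailed interval used as a simple stand-in for the HDI.
pre_ds["interval"] = da.quantile([0.03, 0.97], dim=("chain", "draw"))

summaries = {}
for team in teams:
    sel = {"team": team}
    raw = da.sel(sel)      # raw draws for this subplot
    pre = pre_ds.sel(sel)  # matching precomputed summaries, same selection
    summaries[team] = (
        raw.shape,
        float(pre["point_estimate"]),
        pre["interval"].values,
    )
```

Inside the loop nothing is recomputed: each iteration only looks up values, which is what an artist function would receive when preprocessed data is used.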

plot_trace#

Even with a very preliminary proof of concept like this, it is not excessively difficult to mimic more complicated facetting patterns like that of plot_trace:

from copy import copy
pc = PlotCollection.grid(
    post["atts"].expand_dims(column=2), # add dummy dim for the extra 2 col facetting
    rows=["team"],
    cols=["column"],
    aes={"color": ["chain"]}, 
    color=[f"C{i}" for i in range(4)],
    subplot_kws={"figsize": (12, 8)}
)
pc_right = copy(pc)
pc_right.data = pc.data.isel(column=1)
pc_right.viz = pc.viz.isel(column=1)
pc_right.map(
    lambda values, target, backend, **kwargs: target.plot(
        np.arange(len(values)), values, **kwargs
    )[0]
)
pc_left = copy(pc)
pc_left.data = pc.data.isel(column=0)
pc_left.viz = pc.viz.isel(column=0)
pc_left.map(kde_artist)
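The dummy dimension trick relies only on xarray's expand_dims and isel. A minimal sketch on toy data (the array below is made up) shows that both columns hold the same values, so each one can be handed to a different artist function:

```python
import numpy as np
import xarray as xr

da = xr.DataArray(
    np.arange(12.0).reshape(2, 3, 2),
    dims=("chain", "draw", "team"),
    coords={"team": ["Wales", "Italy"]},
)

# Passing an integer adds a new leading dimension of that length.
expanded = da.expand_dims(column=2)

# Each slice along the dummy dimension is a copy of the original data.
left = expanded.isel(column=0)
right = expanded.isel(column=1)
```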

And the compact version of plot_trace:

pc = PlotCollection.grid(
    post["atts"].expand_dims(column=2), # add dummy dim for the extra 2 col facetting
    rows=["team"],
    cols=["column"],
    aes={"color": ["chain"]}, 
    color=[f"C{i}" for i in range(4)],
    subplot_kws={"figsize": (12, 8)}
)
pc_right = copy(pc)
pc_right.data = pc.data.isel(column=1)
pc_right.viz = pc.viz.isel(column=1)
pc_right.map(
    lambda values, target, backend, **kwargs: target.plot(
        np.arange(len(values)), values, **kwargs
    )[0]
)
pc_left = copy(pc)
pc_left.data = pc.data.isel(column=0)
pc_left.viz = pc.viz.isel(column=0)
pc_left.map(kde_artist, ignore_aes={"color"})

plot_forest#

# main trick for now is defining the y positions manually beforehand
y = np.arange(6)[:, None] + np.linspace(-.2, .2, 4)[None, :]

pc = PlotCollection.wrap(
    post["atts"],
    aes={"color": ["chain"], "y": ["team", "chain"]},
    y=y.flatten(),
    color=[f"C{i}" for i in range(4)],
    subplot_kws={"figsize": (12, 8)}
)
pc.map(visuals.interval, "hdi", linewidth=1)
pc.map(
    visuals.interval, 
    "quartile_range", 
    linewidth=4, 
    interval_func=lambda values: np.quantile(values, (.25, .75))
)
pc.map(visuals.point, "mean", color="black", size=20, zorder=2)
pc.viz["plot"].item().set_yticks(np.arange(6), post["team"].values);
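The y positions above come from NumPy broadcasting: one integer position per team plus a small offset per chain. The sketch below unpacks that computation, assuming the flattened values are assigned in team-major order, as the ["team", "chain"] order of the y aesthetic suggests.

```python
import numpy as np

n_teams, n_chains = 6, 4

# (6, 1) team positions + (1, 4) chain offsets broadcast to a (6, 4) array.
y = np.arange(n_teams)[:, None] + np.linspace(-0.2, 0.2, n_chains)[None, :]
flat = y.flatten()

# Position in the flattened array of team index 4, chain index 2.
idx = np.ravel_multi_index((4, 2), y.shape)
print(y.shape, idx, flat[idx])
```

Each team's chains are centered on its integer position, which is why setting the y ticks at np.arange(6) lines the labels up with the middle of each group.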

Legends will come soon.