"""Plot collection classes."""
from importlib import import_module
import numpy as np
import xarray as xr
from arviz.sel_utils import xarray_sel_iter
def sel_subset(sel, present_dims):
    """Restrict a selection dictionary to the dimensions that are present.

    Parameters
    ----------
    sel : dict
        Mapping of dimension name to the label(s) to select.
    present_dims : container of str
        Dimension names that the target object actually has.

    Returns
    -------
    dict
        New dict with only the ``sel`` entries whose key is in ``present_dims``,
        in the original insertion order.
    """
    restricted = {}
    for dim_name, label in sel.items():
        if dim_name in present_dims:
            restricted[dim_name] = label
    return restricted
def subset_ds(ds, var_name, sel):
    """Return the scalar stored in ``ds[var_name]`` at the position described by ``sel``.

    Keys of ``sel`` that are not dimensions of ``ds[var_name]`` are dropped first,
    so a selection built for a variable with more dimensions can be reused here.
    The remaining selection is expected to leave exactly one element, which is
    converted to a Python scalar with ``.item()``.

    Parameters
    ----------
    ds : xarray.Dataset
        Dataset holding ``var_name``.
    var_name : hashable
        Name of the variable to read.
    sel : dict
        Mapping of dimension name to label.

    Returns
    -------
    object
        The single selected value.
    """
    variable = ds[var_name]
    applicable_sel = sel_subset(sel, variable.dims)
    return variable.sel(applicable_sel).item()
class PlotCollection:
    """Collection of plots laid out over dimensions of the data, plus aesthetic mappings.

    ``self.viz`` holds the layout (``chart``, ``plot``, ``row`` and ``col`` variables
    built by :meth:`wrap` or :meth:`grid`, plus one variable per stored artist), and
    ``self.ds`` holds one variable per aesthetic, indexed by the dimensions that
    aesthetic is mapped to.
    """

    def __init__(self, data, viz_ds, aes=None, backend=None, **kwargs):
        """Initialize the collection and build the aesthetic mappings.

        Parameters
        ----------
        data : xarray object
            Data to plot; presumably an ``xarray.Dataset`` — confirm for other types.
            Must contain every dimension referenced in ``aes``.
        viz_ds : xarray.Dataset
            Layout dataset; must contain a ``plot`` variable whose dimensions are the
            plot-level loop dimensions.
        aes : dict of {str : sequence of str}, optional
            Mapping of aesthetic name to the dimensions it varies over.
        backend : str, optional
            Backend name; only stored as ``self.backend`` when not None.
        **kwargs
            Values for each aesthetic, keyed by aesthetic name. They are truncated or
            cycled to match the number of combinations of the mapped dimensions. A
            missing key defaults to ``[None]``.

        Raises
        ------
        ValueError
            If an empty list of values is given for an aesthetic mapped to non-empty
            dimensions.
        """
        self.data = data
        self.preprocessed_data = None
        self.viz = viz_ds
        self.ds = xr.Dataset()
        # NOTE(review): backend is left unset when None; `map` reads self.backend,
        # so this relies on a class-level/subclass default — confirm intent.
        if backend is not None:
            self.backend = backend
        if aes is None:
            aes = {}
        for aes_key, dims in aes.items():
            aes_raw_shape = [len(data[dim]) for dim in dims]
            n_aes = self._n_elements(aes_raw_shape)
            aes_vals = kwargs.get(aes_key, [None])
            n_aes_vals = len(aes_vals)
            if n_aes_vals == 0 and n_aes > 0:
                raise ValueError(f"No values were provided for aesthetic '{aes_key}'")
            if n_aes_vals > n_aes:
                aes_vals = aes_vals[:n_aes]
            elif n_aes_vals < n_aes:
                # cycle the given values until every combination gets one
                aes_vals = np.tile(aes_vals, (n_aes // n_aes_vals) + 1)[:n_aes]
            self.ds[aes_key] = xr.DataArray(
                np.array(aes_vals).reshape(aes_raw_shape),
                dims=dims,
                coords={dim: data.coords[dim] for dim in dims if dim in data.coords},
            )
        self.aes = aes

    @staticmethod
    def _n_elements(shape):
        """Return the number of elements of an array with ``shape`` as a Python int."""
        # np.prod of an empty sequence is the float 1.0, unusable for slicing/tiling
        return int(np.prod(shape))

    @staticmethod
    def _wrap_shape(n_plots, col_wrap):
        """Return ``(n_rows, n_cols)`` to lay out ``n_plots`` with at most ``col_wrap`` columns."""
        if n_plots <= col_wrap:
            return 1, n_plots
        # ceiling division: an exact multiple of col_wrap needs no extra (empty) row
        return -(-n_plots // col_wrap), col_wrap

    @property
    def base_loop_dims(self):
        """Set of dimension names of the ``plot`` variable in ``self.viz``."""
        return set(self.viz["plot"].dims)

    @classmethod
    def wrap(
        cls,
        data,
        cols=None,
        col_wrap=4,
        backend="matplotlib",
        plot_grid_kws=None,
        **kwargs,
    ):
        """Create a collection with one plot per combination of ``cols``, wrapped into rows.

        Parameters
        ----------
        data : xarray object
            Data to plot; must contain every dimension in ``cols``.
        cols : sequence of str, optional
            Dimensions defining the plots. If None, a single plot is created.
        col_wrap : int, default 4
            Maximum number of columns before starting a new row.
        backend : str, default "matplotlib"
            Name of the ``xrtist.backend`` submodule providing ``create_plotting_grid``.
        plot_grid_kws : dict, optional
            Extra keyword arguments for ``create_plotting_grid``.
        **kwargs
            Forwarded to the constructor (e.g. ``aes`` and aesthetic values).

        Returns
        -------
        PlotCollection
        """
        if plot_grid_kws is None:
            plot_grid_kws = {}
        if cols is None:
            plots_raw_shape = ()
            n_plots = 1
        else:
            plots_raw_shape = [len(data[col]) for col in cols]
            n_plots = cls._n_elements(plots_raw_shape)
        n_rows, n_cols = cls._wrap_shape(n_plots, col_wrap)
        plot_bknd = import_module(f".backend.{backend}", package="xrtist")
        fig, ax_ary = plot_bknd.create_plotting_grid(n_plots, n_rows, n_cols, **plot_grid_kws)
        # row/col ids for every grid cell, enumerated row by row after flatten()
        col_id, row_id = np.meshgrid(np.arange(n_cols), np.arange(n_rows))
        if n_plots > 1:
            # the grid may have more cells than plots; keep only the first n_plots
            viz_ds = xr.Dataset(
                {
                    "chart": fig,
                    "plot": (cols, ax_ary.flatten()[:n_plots].reshape(plots_raw_shape)),
                    "row": (cols, row_id.flatten()[:n_plots].reshape(plots_raw_shape)),
                    "col": (cols, col_id.flatten()[:n_plots].reshape(plots_raw_shape)),
                },
                coords={col: data[col] for col in cols},
            )
        else:
            viz_ds = xr.Dataset(
                {
                    "chart": fig,
                    "plot": ax_ary,
                    "row": 0,
                    "col": 0,
                },
            )
        return cls(data, viz_ds, backend=backend, **kwargs)

    @classmethod
    def grid(
        cls,
        data,
        cols,
        rows,
        backend="matplotlib",
        plot_grid_kws=None,
        **kwargs,
    ):
        """Create a collection laid out as a grid of ``rows`` by ``cols`` dimensions.

        Parameters
        ----------
        data : xarray object
            Data to plot; must contain every dimension in ``cols`` and ``rows``.
        cols, rows : sequence of str
            Dimensions whose combinations define the grid columns and rows.
        backend : str, default "matplotlib"
            Name of the ``xrtist.backend`` submodule providing ``create_plotting_grid``.
        plot_grid_kws : dict, optional
            Extra keyword arguments for ``create_plotting_grid``.
        **kwargs
            Forwarded to the constructor (e.g. ``aes`` and aesthetic values).

        Returns
        -------
        PlotCollection
        """
        if plot_grid_kws is None:
            plot_grid_kws = {}
        n_cols = cls._n_elements([len(data[col]) for col in cols])
        n_rows = cls._n_elements([len(data[row]) for row in rows])
        n_plots = n_cols * n_rows
        plot_bknd = import_module(f".backend.{backend}", package="xrtist")
        fig, ax_ary = plot_bknd.create_plotting_grid(n_plots, n_rows, n_cols, **plot_grid_kws)
        dims = tuple((*rows, *cols))  # use provided dim orders, not existing ones
        col_id, row_id = np.meshgrid(np.arange(n_cols), np.arange(n_rows))
        plots_raw_shape = [len(data[dim]) for dim in dims]
        viz_ds = xr.Dataset(
            {
                "chart": fig,
                "plot": (dims, ax_ary.flatten().reshape(plots_raw_shape)),
                "row": (dims, row_id.flatten().reshape(plots_raw_shape)),
                "col": (dims, col_id.flatten().reshape(plots_raw_shape)),
            },
            coords={dim: data[dim] for dim in dims},
        )
        return cls(data, viz_ds, backend=backend, **kwargs)

    def _update_aes(self, ignore_aes=frozenset()):
        """Return the active aesthetics and the dimensions to loop over.

        Returns
        -------
        aes : list of str
            Aesthetic names not in ``ignore_aes``.
        all_loop_dims : set of str
            Plot-level loop dims plus the dims of every active aesthetic.
        """
        aes = [aes_key for aes_key in self.aes.keys() if aes_key not in ignore_aes]
        aes_dims = [dim for aes_key in aes for dim in self.aes[aes_key]]
        all_loop_dims = self.base_loop_dims.union(aes_dims)
        return aes, all_loop_dims

    def plot_iterator(self, ignore_aes=frozenset()):
        """Yield ``(var_name, sel, isel)`` for every subset ``map`` would plot."""
        _, all_loop_dims = self._update_aes(ignore_aes)
        yield from xarray_sel_iter(
            self.data, skip_dims={dim for dim in self.data.dims if dim not in all_loop_dims}
        )

    def map(
        self,
        fun,
        fun_label=None,
        *,
        ignore_aes=frozenset(),
        preprocessed=False,
        subset_info=False,
        store_artist=True,
        **kwargs,
    ):
        """Call ``fun`` once per data subset, targeting the plot that subset belongs to.

        Parameters
        ----------
        fun : callable
            Called as ``fun(da, target=plot, **fun_kwargs)``. ``da`` is the selected subset
            of the iterated variable. ``fun_kwargs`` holds the aesthetic values for the
            subset, ``kwargs``, ``backend`` and, optionally, ``preprocessed_data`` and
            ``var_name``/``sel``/``isel``.
        fun_label : str, optional
            Name of the ``self.viz`` variable where returned artists are stored.
            Defaults to ``fun.__name__``.
        ignore_aes : set of str, optional
            Aesthetics not to apply. Their dimensions are only looped over when
            required by the plots or by other aesthetics.
        preprocessed : bool, default False
            Pass the matching subset of ``self.preprocessed_data`` as ``preprocessed_data``.
        subset_info : bool, default False
            Also pass ``var_name``, ``sel`` and ``isel`` to ``fun``.
        store_artist : bool, default True
            Store each return value of ``fun`` in ``self.viz[fun_label]``.
        **kwargs
            Extra keyword arguments for ``fun``; they override aesthetic values.

        Raises
        ------
        ValueError
            If ``preprocessed=True`` but ``self.preprocessed_data`` has not been set.
        """
        # loop-invariant precondition: fail before creating any artist storage
        if preprocessed and self.preprocessed_data is None:
            raise ValueError(
                "You must manually set the `preprocessed_data` to use preprocessed=True"
            )
        aes, all_loop_dims = self._update_aes(ignore_aes)
        plotters = xarray_sel_iter(
            self.data, skip_dims={dim for dim in self.data.dims if dim not in all_loop_dims}
        )
        artist_dims = [dim for dim in self.data.dims if dim in all_loop_dims]
        artist_shape = [len(self.data[dim]) for dim in artist_dims]
        if fun_label is None:
            fun_label = fun.__name__
        if store_artist:
            # NOTE(review): artists have no variable dimension, so different variables
            # sharing the same sel overwrite each other's stored artist — confirm intent.
            self.viz[fun_label] = xr.DataArray(
                np.empty(artist_shape, dtype=object),
                dims=artist_dims,
                coords={dim: self.data[dim] for dim in artist_dims},
            )
        for var_name, sel, isel in plotters:
            # hand fun only the iterated variable, not every variable of the dataset
            source = self.data[var_name] if isinstance(self.data, xr.Dataset) else self.data
            da = source.sel(sel)
            target = subset_ds(self.viz, "plot", sel)
            aes_kwargs = {aes_key: subset_ds(self.ds, aes_key, sel) for aes_key in aes}
            fun_kwargs = {**aes_kwargs, **kwargs}
            fun_kwargs["backend"] = self.backend
            if preprocessed:
                pre_da = self.preprocessed_data.sel(sel_subset(sel, self.preprocessed_data.dims))
                fun_kwargs["preprocessed_data"] = pre_da
            if subset_info:
                fun_kwargs = {**fun_kwargs, "var_name": var_name, "sel": sel, "isel": isel}
            aux_artist = fun(da, target=target, **fun_kwargs)
            if store_artist:
                self.viz[fun_label].loc[sel] = aux_artist

    def add_legend(self, aes, artist, **kwargs):
        """Add a legend for ``aes`` using ``artist``; currently a no-op placeholder."""
        pass