"""Plot collection classes."""
from importlib import import_module
import numpy as np
import xarray as xr
from arviz.sel_utils import xarray_sel_iter
from datatree import DataTree
def sel_subset(sel, present_dims):
    """Restrict a selection mapping to the dimensions that are present.

    Parameters
    ----------
    sel : mapping
        Selection whose keys are dimension names.
    present_dims : container
        Dimension names to keep; any container supporting ``in`` works.

    Returns
    -------
    dict
        New dictionary with the ``sel`` items whose key is in ``present_dims``,
        in the iteration order of ``sel``.
    """
    subset = {}
    for dim, coord in sel.items():
        if dim in present_dims:
            subset[dim] = coord
    return subset
def subset_ds(ds, var_name, sel):
    """Extract the value(s) of one variable of ``ds`` for a given selection.

    Only the keys of ``sel`` that are dimensions of ``ds[var_name]`` are used,
    so ``sel`` may contain dimensions the variable does not have.

    Parameters
    ----------
    ds : mapping of variables
        Object indexable by ``var_name`` whose items expose ``dims``, ``sel``,
        ``size``, ``item`` and ``values`` (presumably an xarray Dataset or a
        DataTree node -- confirm against callers).
    var_name : hashable
        Name of the variable to subset.
    sel : mapping
        Selection with dimension names as keys.

    Returns
    -------
    object
        ``.item()`` of the subset when it holds exactly one element, otherwise
        its ``.values``.
    """
    variable = ds[var_name]
    applicable_sel = sel_subset(sel, variable.dims)
    # An empty selection means there is nothing to subset along
    selected = variable.sel(applicable_sel) if applicable_sel else variable
    return selected.item() if selected.size == 1 else selected.values
def _process_facet_dims(data, facet_dims):
if not facet_dims:
return 1, {}
facets_per_var = {}
if "__variable__" in facet_dims:
for var_name, da in data.items():
lenghts = [len(da[dim]) for dim in facet_dims if dim in da.dims]
facets_per_var[var_name] = np.prod(lenghts) if lenghts else 1
n_facets = np.sum(list(facets_per_var.values()))
else:
missing_dims = {
var_name: [dim for dim in facet_dims if dim not in da.dims]
for var_name, da in data.items()
}
missing_dims = {k: v for k, v in missing_dims.items() if v}
if any(missing_dims.values()):
raise ValueError(
"All variables must have all facetting dimensions, but found the following "
f"dims to be missing in these variables: {missing_dims}"
)
n_facets = np.prod([data.sizes[dim] for dim in facet_dims])
return n_facets, facets_per_var
class PlotCollection:
    """Collection of plots and their aesthetics stored as xarray Datasets.

    Attributes
    ----------
    data : xarray.Dataset
        Data to be plotted (assumed to be a Dataset since it is iterated with
        ``xarray_sel_iter`` -- confirm against callers).
    preprocessed_data
        Optional preprocessed data; starts as None and must be set manually
        before calling `map` with ``preprocessed=True``.
    viz : xarray.Dataset
        Dataset with the plotting objects. It must contain a ``"plot"`` variable;
        `map` adds one variable per mapped function when storing artists.
    ds : xarray.Dataset
        One variable per aesthetic with the value for each coordinate combination.
    aes : dict
        Mapping of aesthetic names to the dimensions they are mapped to.
    """

    def __init__(self, data, viz_ds, aes=None, backend=None, **kwargs):
        """Initialize the collection and generate the aesthetics dataset.

        Parameters
        ----------
        data : xarray.Dataset
            Data to be plotted.
        viz_ds : xarray.Dataset
            Dataset with the plotting objects, stored as ``viz``.
        aes : dict, optional
            Mapping of aesthetic names to the list of dimensions they encode.
        backend : str, optional
            Plotting backend. The ``backend`` attribute is only set when this is
            not None.
        **kwargs
            Values to use for each aesthetic, keyed by aesthetic name. They are
            truncated or recycled to match the number of coordinate combinations
            of the aesthetic dimensions. Missing aesthetics default to ``[None]``.
        """
        self.data = data
        self.preprocessed_data = None
        self.viz = viz_ds
        self.ds = xr.Dataset()

        # NOTE(review): `map` reads ``self.backend``, which only exists if it was
        # passed here or is set elsewhere (e.g. by a subclass) -- confirm.
        if backend is not None:
            self.backend = backend

        if aes is None:
            aes = {}

        for aes_key, dims in aes.items():
            aes_raw_shape = [len(data[dim]) for dim in dims]
            # int() keeps slicing/tiling valid when `dims` is empty (np.prod([]) is 1.0)
            n_aes = int(np.prod(aes_raw_shape))
            aes_vals = kwargs.get(aes_key, [None])
            n_aes_vals = len(aes_vals)
            if n_aes_vals > n_aes:
                aes_vals = aes_vals[:n_aes]
            elif n_aes_vals < n_aes:
                # recycle the provided values until all combinations are covered
                aes_vals = np.tile(aes_vals, (n_aes // n_aes_vals) + 1)[:n_aes]
            self.ds[aes_key] = xr.DataArray(
                np.array(aes_vals).reshape(aes_raw_shape),
                dims=dims,
                coords={dim: data.coords[dim] for dim in dims if dim in data.coords},
            )
        self.aes = aes

    @property
    def base_loop_dims(self):
        """Set with the dimensions of the ``"plot"`` variable in ``viz``."""
        return set(self.viz["plot"].dims)

    @classmethod
    def wrap(
        cls,
        data,
        cols=None,
        col_wrap=4,
        backend="matplotlib",
        plot_grid_kws=None,
        **kwargs,
    ):
        """Instantiate a collection with plots laid out row by row and wrapped.

        Parameters
        ----------
        data : xarray.Dataset
            Data to be plotted.
        cols : sequence of str, optional
            Dimensions to facet over. A single plot is created when None.
        col_wrap : int, default 4
            Maximum number of columns of the plotting grid.
        backend : str, default "matplotlib"
            Name of the ``xrtist.backend`` submodule used to create the grid.
        plot_grid_kws : dict, optional
            Extra keyword arguments for the backend ``create_plotting_grid``.
        **kwargs
            Passed to the class initializer.

        Returns
        -------
        PlotCollection
        """
        if plot_grid_kws is None:
            plot_grid_kws = {}

        if cols is None:
            plots_raw_shape = ()
            n_plots = 1
        else:
            plots_raw_shape = [len(data[col]) for col in cols]
            n_plots = np.prod(plots_raw_shape)

        if n_plots <= col_wrap:
            n_rows, n_cols = 1, n_plots
        else:
            # ceil division: only add a row for the leftover plots if there are any
            full_rows, leftover = divmod(n_plots, col_wrap)
            n_rows = full_rows + (leftover != 0)
            n_cols = col_wrap

        plot_bknd = import_module(f".backend.{backend}", package="xrtist")
        fig, ax_ary = plot_bknd.create_plotting_grid(n_plots, n_rows, n_cols, **plot_grid_kws)
        col_id, row_id = np.meshgrid(np.arange(n_cols), np.arange(n_rows))

        if n_plots > 1:
            viz_ds = xr.Dataset(
                {
                    "chart": fig,
                    # the grid can have more cells than plots, keep the first n_plots
                    "plot": (cols, ax_ary.flatten()[:n_plots].reshape(plots_raw_shape)),
                    "row": (cols, row_id.flatten()[:n_plots].reshape(plots_raw_shape)),
                    "col": (cols, col_id.flatten()[:n_plots].reshape(plots_raw_shape)),
                },
                coords={col: data[col] for col in cols},
            )
        else:
            viz_ds = xr.Dataset(
                {
                    "chart": fig,
                    "plot": ax_ary,
                    "row": 0,
                    "col": 0,
                },
            )
        return cls(data, viz_ds, backend=backend, **kwargs)

    @classmethod
    def grid(
        cls,
        data,
        cols,
        rows,
        backend="matplotlib",
        plot_grid_kws=None,
        **kwargs,
    ):
        """Instantiate a collection with plots laid out on a rows x cols grid.

        Parameters
        ----------
        data : xarray.Dataset
            Data to be plotted.
        cols, rows : sequence of str
            Dimensions to facet over along the columns and rows respectively.
        backend : str, default "matplotlib"
            Name of the ``xrtist.backend`` submodule used to create the grid.
        plot_grid_kws : dict, optional
            Extra keyword arguments for the backend ``create_plotting_grid``.
        **kwargs
            Passed to the class initializer.

        Returns
        -------
        PlotCollection
        """
        if plot_grid_kws is None:
            plot_grid_kws = {}

        n_cols = np.prod([data.sizes[col] for col in cols])
        n_rows = np.prod([data.sizes[row] for row in rows])
        n_plots = n_cols * n_rows

        plot_bknd = import_module(f".backend.{backend}", package="xrtist")
        fig, ax_ary = plot_bknd.create_plotting_grid(n_plots, n_rows, n_cols, **plot_grid_kws)

        dims = tuple((*rows, *cols))  # use provided dim orders, not existing ones
        col_id, row_id = np.meshgrid(np.arange(n_cols), np.arange(n_rows))
        plots_raw_shape = [len(data[dim]) for dim in dims]

        viz_ds = xr.Dataset(
            {
                "chart": fig,
                "plot": (dims, ax_ary.flatten().reshape(plots_raw_shape)),
                "row": (dims, row_id.flatten().reshape(plots_raw_shape)),
                "col": (dims, col_id.flatten().reshape(plots_raw_shape)),
            },
            coords={dim: data[dim] for dim in dims},
        )
        return cls(data, viz_ds, backend=backend, **kwargs)

    def _update_aes(self, ignore_aes=frozenset()):
        """Get the active aesthetics and the dimensions to loop over.

        Parameters
        ----------
        ignore_aes : container of str, optional
            Aesthetics to exclude.

        Returns
        -------
        aes : list of str
            Aesthetic names in ``self.aes`` that are not in `ignore_aes`.
        all_loop_dims : set of str
            Union of `base_loop_dims` and the dimensions of the active aesthetics.
        """
        aes = [aes_key for aes_key in self.aes if aes_key not in ignore_aes]
        aes_dims = [dim for aes_key in aes for dim in self.aes[aes_key]]
        all_loop_dims = self.base_loop_dims.union(aes_dims)
        return aes, all_loop_dims

    def plot_iterator(self, ignore_aes=frozenset()):
        """Iterate over the ``(var_name, sel, isel)`` triplets that `map` loops over.

        Parameters
        ----------
        ignore_aes : container of str, optional
            Aesthetics whose dimensions should not be looped over.

        Yields
        ------
        tuple
            The items generated by ``xarray_sel_iter``, unpacked here as
            ``var_name, sel, isel``.
        """
        _, all_loop_dims = self._update_aes(ignore_aes)
        plotters = xarray_sel_iter(
            self.data, skip_dims={dim for dim in self.data.dims if dim not in all_loop_dims}
        )
        yield from plotters

    def map(
        self,
        fun,
        fun_label=None,
        *,
        ignore_aes=frozenset(),
        preprocessed=False,
        subset_info=False,
        store_artist=True,
        **kwargs,
    ):
        """Call a plotting function once per subset of the data.

        `fun` is called as ``fun(da, target=target, **fun_kwargs)`` where ``da`` is
        ``self.data.sel(sel)``, ``target`` is the matching element of
        ``viz["plot"]`` and ``fun_kwargs`` combines the aesthetic values for the
        subset, `kwargs` (which take precedence) and ``backend``.

        Parameters
        ----------
        fun : callable
            Function to call on each subset.
        fun_label : str, optional
            Name under which the returned artists are stored in ``viz``.
            Defaults to ``fun.__name__``.
        ignore_aes : container of str, optional
            Aesthetics to skip, both as loop dimensions and as keyword arguments.
        preprocessed : bool, default False
            If True, pass the matching subset of ``preprocessed_data`` as the
            ``preprocessed_data`` keyword argument.
        subset_info : bool, default False
            If True, pass ``var_name``, ``sel`` and ``isel`` as keyword arguments.
        store_artist : bool, default True
            If True, store the return value of each call in ``viz[fun_label]``.
        **kwargs
            Passed to `fun`.

        Raises
        ------
        ValueError
            If ``preprocessed=True`` but ``preprocessed_data`` has not been set.
        """
        aes, all_loop_dims = self._update_aes(ignore_aes)
        plotters = xarray_sel_iter(
            self.data, skip_dims={dim for dim in self.data.dims if dim not in all_loop_dims}
        )
        artist_dims = [dim for dim in self.data.dims if dim in all_loop_dims]
        artist_shape = [len(self.data[dim]) for dim in artist_dims]

        if fun_label is None:
            fun_label = fun.__name__

        if store_artist:
            # object array to be filled with one artist per loop iteration
            self.viz[fun_label] = xr.DataArray(
                np.empty(artist_shape, dtype=object),
                dims=artist_dims,
                coords={dim: self.data[dim] for dim in artist_dims},
            )

        for var_name, sel, isel in plotters:
            da = self.data.sel(sel)
            target = subset_ds(self.viz, "plot", sel)

            aes_kwargs = {}
            for aes_key in aes:
                aes_kwargs[aes_key] = subset_ds(self.ds, aes_key, sel)

            # explicit kwargs override the aesthetic mappings
            fun_kwargs = {**aes_kwargs, **kwargs}
            fun_kwargs["backend"] = self.backend
            if preprocessed:
                if self.preprocessed_data is None:
                    raise ValueError(
                        "You must manually set the `preprocessed_data` to use preprocessed=True"
                    )
                pre_da = self.preprocessed_data.sel(sel_subset(sel, self.preprocessed_data.dims))
                fun_kwargs["preprocessed_data"] = pre_da
            if subset_info:
                fun_kwargs = {**fun_kwargs, "var_name": var_name, "sel": sel, "isel": isel}

            aux_artist = fun(da, target=target, **fun_kwargs)
            if store_artist:
                self.viz[fun_label].loc[sel] = aux_artist

    def add_legend(self, aes, artist, **kwargs):
        """Add a legend for an aesthetic. Currently a placeholder that does nothing."""
        pass
class PlotMuseum:
    """Collection of plots and their aesthetics stored as DataTrees.

    Attributes
    ----------
    data : xarray.Dataset
        Data to be plotted (assumed to be a Dataset since it is iterated with
        ``xarray_sel_iter`` and ``.items()`` -- confirm against callers).
    preprocessed_data
        Optional preprocessed data; starts as None and must be set manually
        before calling `map` with ``preprocessed=True``.
    viz : DataTree
        Tree with the plotting objects. The ``"plot"`` variable lives either in
        the root node (shared by all variables) or in one child node per variable.
    dt : DataTree or None
        Tree with one child per variable holding the aesthetic values. Generated
        lazily by `map` through `generate_aes_dt` when it is None.
    """

    def __init__(self, data, viz_dt, aes_dt=None, aes=None, backend=None, **kwargs):
        """Initialize the collection without generating the aesthetics tree.

        Parameters
        ----------
        data : xarray.Dataset
            Data to be plotted.
        viz_dt : DataTree
            Tree with the plotting objects, stored as ``viz``.
        aes_dt : DataTree, optional
            Pre-generated aesthetics tree, stored as ``dt``.
        aes : dict, optional
            Mapping of aesthetic names to the list of dimensions they encode.
        backend : str, optional
            Plotting backend. The ``backend`` attribute is only set when this is
            not None.
        **kwargs
            Values for each aesthetic, keyed by aesthetic name. Stored and used
            later by `generate_aes_dt`.
        """
        self.data = data
        self.preprocessed_data = None
        self.viz = viz_dt
        self.dt = aes_dt

        # NOTE(review): `map` reads ``self.backend``, which only exists if it was
        # passed here or is set elsewhere (e.g. by a subclass) -- confirm.
        if backend is not None:
            self.backend = backend

        if aes is None:
            aes = {}

        self._aes = aes
        self._kwargs = kwargs

    def generate_aes_dt(self, aes, **kwargs):
        """Generate the aesthetics tree ``dt``, replacing any existing one.

        For every variable, each aesthetic is restricted to the dimensions that
        the variable actually has. If the variable has none of them, the first
        aesthetic value is stored as a scalar.

        Parameters
        ----------
        aes : dict or None
            Mapping of aesthetic names to the list of dimensions they encode.
        **kwargs
            Values to use for each aesthetic, keyed by aesthetic name. They are
            truncated or recycled to match the number of coordinate combinations.
            Missing aesthetics default to ``[None]``.
        """
        if aes is None:
            aes = {}
        self._aes = aes
        self._kwargs = kwargs
        self.dt = DataTree()
        for var_name, da in self.data.items():
            ds = xr.Dataset()
            for aes_key, dims in aes.items():
                aes_vals = kwargs.get(aes_key, [None])
                # The array shape and its dims must both come from the dims present
                # in this variable, otherwise DataArray creation fails on mismatch.
                aes_dims = [dim for dim in dims if dim in da.dims]
                aes_raw_shape = [da.sizes[dim] for dim in aes_dims]
                if not aes_raw_shape:
                    ds[aes_key] = aes_vals[0]
                    continue
                n_aes = np.prod(aes_raw_shape)
                n_aes_vals = len(aes_vals)
                if n_aes_vals > n_aes:
                    aes_vals = aes_vals[:n_aes]
                elif n_aes_vals < n_aes:
                    # recycle the provided values until all combinations are covered
                    aes_vals = np.tile(aes_vals, (n_aes // n_aes_vals) + 1)[:n_aes]
                ds[aes_key] = xr.DataArray(
                    np.array(aes_vals).reshape(aes_raw_shape),
                    dims=aes_dims,
                    coords={dim: da.coords[dim] for dim in aes_dims if dim in da.coords},
                )
            DataTree(name=var_name, parent=self.dt, data=ds)

    @property
    def base_loop_dims(self):
        """Set with the dimensions of the ``"plot"`` variable(s) in ``viz``.

        Taken from the root node if it has a ``"plot"`` variable, otherwise the
        union over the ``"plot"`` variables of all children.
        """
        if "plot" in self.viz.data_vars:
            return set(self.viz["plot"].dims)
        return {dim for da in self.viz.children.values() for dim in da["plot"].dims}

    def get_viz(self, var_name):
        """Return the ``viz`` node holding the ``"plot"`` variable for `var_name`.

        This is the root node if it has a ``"plot"`` variable, otherwise
        ``viz[var_name]``.
        """
        return self.viz if "plot" in self.viz.data_vars else self.viz[var_name]

    @classmethod
    def wrap(
        cls,
        data,
        cols=None,
        col_wrap=4,
        backend="matplotlib",
        plot_grid_kws=None,
        **kwargs,
    ):
        """Instantiate a collection with plots laid out row by row and wrapped.

        Parameters
        ----------
        data : xarray.Dataset
            Data to be plotted.
        cols : sequence of str, optional
            Dimensions to facet over; may include ``"__variable__"`` to facet
            over variables. A single plot is created when None or empty.
        col_wrap : int, default 4
            Maximum number of columns of the plotting grid.
        backend : str, default "matplotlib"
            Name of the ``xrtist.backend`` submodule used to create the grid.
        plot_grid_kws : dict, optional
            Extra keyword arguments for the backend ``create_plotting_grid``.
        **kwargs
            Passed to the class initializer.

        Returns
        -------
        PlotMuseum
        """
        if cols is None:
            cols = []
        if plot_grid_kws is None:
            plot_grid_kws = {}

        n_plots, plots_per_var = _process_facet_dims(data, cols)
        if n_plots <= col_wrap:
            n_rows, n_cols = 1, n_plots
        else:
            # ceil division: only add a row for the leftover plots if there are any
            div_mod = divmod(n_plots, col_wrap)
            n_rows = div_mod[0] + (div_mod[1] != 0)
            n_cols = col_wrap

        plot_bknd = import_module(f".backend.{backend}", package="xrtist")
        fig, ax_ary = plot_bknd.create_plotting_grid(
            n_plots, n_rows, n_cols, squeeze=False, **plot_grid_kws
        )
        col_id, row_id = np.meshgrid(np.arange(n_cols), np.arange(n_rows))

        viz_dict = {}
        # the grid can have more cells than plots, keep the first n_plots
        flat_ax_ary = ax_ary.flatten()[:n_plots]
        flat_row_id = row_id.flatten()[:n_plots]
        flat_col_id = col_id.flatten()[:n_plots]

        if "__variable__" not in cols:
            dims = cols  # use provided dim orders, not existing ones
            plots_raw_shape = [data.sizes[dim] for dim in dims]
            viz_dict["/"] = xr.Dataset(
                {
                    "chart": fig,
                    "plot": (dims, flat_ax_ary.reshape(plots_raw_shape)),
                    "row": (dims, flat_row_id.reshape(plots_raw_shape)),
                    "col": (dims, flat_col_id.reshape(plots_raw_shape)),
                },
                coords={dim: data[dim] for dim in dims},
            )
        else:
            viz_dict["/"] = xr.Dataset({"chart": xr.DataArray(fig)})
            all_dims = cols
            # running offset into the flattened grid: each variable takes the next
            # `plots_per_var[var_name]` plots
            facet_cumulative = 0
            for var_name, da in data.items():
                dims = [dim for dim in all_dims if dim in da.dims]
                plots_raw_shape = [data.sizes[dim] for dim in dims]
                col_slice = (
                    slice(None, None)
                    if var_name not in plots_per_var
                    else slice(facet_cumulative, facet_cumulative + plots_per_var[var_name])
                )
                facet_cumulative += plots_per_var[var_name]
                viz_dict[var_name] = xr.Dataset(
                    {
                        "plot": (
                            dims,
                            flat_ax_ary[col_slice].reshape(plots_raw_shape),
                        ),
                        "row": (
                            dims,
                            flat_row_id[col_slice].reshape(plots_raw_shape),
                        ),
                        "col": (
                            dims,
                            flat_col_id[col_slice].reshape(plots_raw_shape),
                        ),
                    }
                )
        viz_dt = DataTree.from_dict(viz_dict)
        return cls(data, viz_dt, backend=backend, **kwargs)

    @classmethod
    def grid(
        cls,
        data,
        cols=None,
        rows=None,
        backend="matplotlib",
        plot_grid_kws=None,
        **kwargs,
    ):
        """Instantiate a collection with plots laid out on a rows x cols grid.

        Parameters
        ----------
        data : xarray.Dataset
            Data to be plotted.
        cols, rows : sequence of str, optional
            Dimensions to facet over along the columns and rows respectively.
            Either may include ``"__variable__"`` to facet over variables.
        backend : str, default "matplotlib"
            Name of the ``xrtist.backend`` submodule used to create the grid.
        plot_grid_kws : dict, optional
            Extra keyword arguments for the backend ``create_plotting_grid``.
        **kwargs
            Passed to the class initializer.

        Returns
        -------
        PlotMuseum

        Raises
        ------
        ValueError
            If a dimension appears in both `cols` and `rows`.
        """
        if cols is None:
            cols = []
        if rows is None:
            rows = []
        if plot_grid_kws is None:
            plot_grid_kws = {}

        repeated_dims = [col for col in cols if col in rows]
        if repeated_dims:
            raise ValueError("The same dimension can't be used for both cols and rows.")

        n_cols, cols_per_var = _process_facet_dims(data, cols)
        n_rows, rows_per_var = _process_facet_dims(data, rows)
        n_plots = n_cols * n_rows

        plot_bknd = import_module(f".backend.{backend}", package="xrtist")
        fig, ax_ary = plot_bknd.create_plotting_grid(
            n_plots, n_rows, n_cols, squeeze=False, **plot_grid_kws
        )
        col_id, row_id = np.meshgrid(np.arange(n_cols), np.arange(n_rows))

        viz_dict = {}
        if "__variable__" not in cols and "__variable__" not in rows:
            dims = tuple((*rows, *cols))  # use provided dim orders, not existing ones
            plots_raw_shape = [data.sizes[dim] for dim in dims]
            viz_dict["/"] = xr.Dataset(
                {
                    "chart": fig,
                    "plot": (dims, ax_ary.flatten().reshape(plots_raw_shape)),
                    "row": (dims, row_id.flatten().reshape(plots_raw_shape)),
                    "col": (dims, col_id.flatten().reshape(plots_raw_shape)),
                },
                coords={dim: data[dim] for dim in dims},
            )
        else:
            viz_dict["/"] = xr.Dataset({"chart": xr.DataArray(fig)})
            all_dims = tuple((*rows, *cols))  # use provided dim orders, not existing ones
            # running offset along whichever axis is facetted over variables
            facet_cumulative = 0
            for var_name, da in data.items():
                dims = [dim for dim in all_dims if dim in da.dims]
                plots_raw_shape = [data.sizes[dim] for dim in dims]
                row_slice = (
                    slice(None, None)
                    if var_name not in rows_per_var
                    else slice(facet_cumulative, facet_cumulative + rows_per_var[var_name])
                )
                col_slice = (
                    slice(None, None)
                    if var_name not in cols_per_var
                    else slice(facet_cumulative, facet_cumulative + cols_per_var[var_name])
                )
                if rows_per_var:
                    facet_cumulative += rows_per_var[var_name]
                else:
                    facet_cumulative += cols_per_var[var_name]
                viz_dict[var_name] = xr.Dataset(
                    {
                        "plot": (
                            dims,
                            ax_ary[row_slice, col_slice].flatten().reshape(plots_raw_shape),
                        ),
                        "row": (
                            dims,
                            row_id[row_slice, col_slice].flatten().reshape(plots_raw_shape),
                        ),
                        "col": (
                            dims,
                            col_id[row_slice, col_slice].flatten().reshape(plots_raw_shape),
                        ),
                    }
                )
        viz_dt = DataTree.from_dict(viz_dict)
        return cls(data, viz_dt, backend=backend, **kwargs)

    def _update_aes(self, ignore_aes=frozenset(), coords=None):
        """Get the active aesthetics and the dimensions to loop over.

        Parameters
        ----------
        ignore_aes : container of str, optional
            Aesthetics to exclude.
        coords : mapping, optional
            Selection already applied to the data; its keys are removed from the
            loop dimensions. None is treated as an empty mapping.

        Returns
        -------
        aes : list of str
            Aesthetic names that are not in `ignore_aes`.
        all_loop_dims : set of str
            Union of `base_loop_dims` and the dimensions of the active
            aesthetics, minus the keys of `coords`.
        """
        if coords is None:
            coords = {}
        aes = [aes_key for aes_key in self._aes if aes_key not in ignore_aes]
        aes_dims = [dim for aes_key in aes for dim in self._aes[aes_key]]
        all_loop_dims = self.base_loop_dims.union(aes_dims).difference(coords.keys())
        return aes, all_loop_dims

    def plot_iterator(self, ignore_aes=frozenset(), coords=None):
        """Iterate over the ``(var_name, sel, isel)`` triplets that `map` loops over.

        Parameters
        ----------
        ignore_aes : container of str, optional
            Aesthetics whose dimensions should not be looped over.
        coords : mapping, optional
            Selection applied to the data before iterating, as done in `map`.

        Yields
        ------
        tuple
            The items generated by ``xarray_sel_iter``, unpacked here as
            ``var_name, sel, isel``.
        """
        if coords is None:
            coords = {}
        data = self.data.sel(coords)
        _, all_loop_dims = self._update_aes(ignore_aes, coords)
        plotters = xarray_sel_iter(
            data, skip_dims={dim for dim in data.dims if dim not in all_loop_dims}
        )
        yield from plotters

    def map(
        self,
        fun,
        fun_label=None,
        *,
        coords=None,
        ignore_aes=frozenset(),
        preprocessed=False,
        subset_info=False,
        store_artist=True,
        artist_dims=None,
        **kwargs,
    ):
        """Call a plotting function once per variable and subset of the data.

        `fun` is called as ``fun(da, target=target, **fun_kwargs)`` where ``da`` is
        the subset of one variable, ``target`` is the matching ``"plot"`` element
        and ``fun_kwargs`` combines the aesthetic values for the subset, `kwargs`
        (which take precedence) and ``backend``.

        Parameters
        ----------
        fun : callable
            Function to call on each subset.
        fun_label : str, optional
            Name under which the returned artists are stored in ``viz``.
            Defaults to ``fun.__name__``.
        coords : mapping, optional
            Selection applied to the data before looping. Its keys are excluded
            from the loop dimensions but are used when looking up the target
            plot and the aesthetics.
        ignore_aes : container of str, optional
            Aesthetics to skip, both as loop dimensions and as keyword arguments.
        preprocessed : bool, default False
            If True, pass the matching subset of ``preprocessed_data`` as the
            ``preprocessed_data`` keyword argument.
        subset_info : bool, default False
            If True, pass ``var_name``, ``sel`` and ``isel`` as keyword arguments.
        store_artist : bool, default True
            If True, store the return value of each call in
            ``viz[var_name][fun_label]``.
        artist_dims : mapping of {str : int}, optional
            Extra dimensions (name to length) appended to the array that stores
            the artists.
        **kwargs
            Passed to `fun`.

        Raises
        ------
        ValueError
            If ``preprocessed=True`` but ``preprocessed_data`` has not been set.
        """
        if coords is None:
            coords = {}
        if self.dt is None:
            self.generate_aes_dt(self._aes, **self._kwargs)
        if artist_dims is None:
            artist_dims = {}
        if fun_label is None:
            fun_label = fun.__name__

        data = self.data.sel(coords)
        aes, all_loop_dims = self._update_aes(ignore_aes, coords)
        plotters = xarray_sel_iter(
            data, skip_dims={dim for dim in data.dims if dim not in all_loop_dims}
        )

        if store_artist:
            for var_name, da in data.items():
                if var_name not in self.viz.children:
                    DataTree(name=var_name, parent=self.viz)
                inherited_dims = [dim for dim in da.dims if dim in all_loop_dims]
                artist_shape = [da.sizes[dim] for dim in inherited_dims] + list(
                    artist_dims.values()
                )
                all_artist_dims = inherited_dims + list(artist_dims.keys())
                # object array to be filled with the artists of each loop iteration
                self.viz[var_name][fun_label] = xr.DataArray(
                    np.empty(artist_shape, dtype=object),
                    dims=all_artist_dims,
                    coords={dim: data[dim] for dim in inherited_dims},
                )

        for var_name, sel, isel in plotters:
            da = data[var_name].sel(sel)
            # `coords` dims are not looped over but are needed to find the
            # right plot and aesthetic values
            sel_plus = {**sel, **coords}
            target = subset_ds(self.get_viz(var_name), "plot", sel_plus)

            aes_kwargs = {}
            for aes_key in aes:
                aes_kwargs[aes_key] = subset_ds(self.dt[var_name], aes_key, sel_plus)

            # explicit kwargs override the aesthetic mappings
            fun_kwargs = {**aes_kwargs, **kwargs}
            fun_kwargs["backend"] = self.backend
            if preprocessed:
                if self.preprocessed_data is None:
                    raise ValueError(
                        "You must manually set the `preprocessed_data` to use preprocessed=True"
                    )
                pre_da = self.preprocessed_data.sel(sel_subset(sel, self.preprocessed_data.dims))
                fun_kwargs["preprocessed_data"] = pre_da
            if subset_info:
                fun_kwargs = {**fun_kwargs, "var_name": var_name, "sel": sel, "isel": isel}

            aux_artist = fun(da, target=target, **fun_kwargs)
            if store_artist:
                self.viz[var_name][fun_label].loc[sel] = aux_artist

    def add_legend(self, aes, artist, **kwargs):
        """Add a legend for an aesthetic. Currently a placeholder that does nothing."""
        pass