ArviZ in depth: plot_trace
Who said looking at traces wasn't fun?

Introduction
plot_trace is one of the most common plots used to assess the convergence of MCMC runs, and therefore it is also one of the most used ArviZ functions. plot_trace has a lot of parameters that allow creating highly customizable plots, but they may not be straightforward to use. There is no single culprit behind the convoluted arguments and their formats: ArviZ has to integrate with several libraries such as xarray and matplotlib, which provide amazing features and customization power, and we'd like ArviZ users to have access to all of these features. However, we also aim to keep ArviZ simple to use and with sensible defaults; plot_xyz(idata) should generate acceptable results in most situations.

This post aims to be an extension to the API section on plot_trace, focusing mostly on arguments where examples may be lacking and arguments that appear often in questions posted to ArviZ issues. Therefore, the most common arguments such as var_names will not be covered, and for arguments that I do not remember appearing in issues or generating confusion only some examples will be shown, without an in-depth description.

import arviz as az
import matplotlib.pyplot as plt
import numpy as np
import xarray as xr
# html render is not correctly rendered in blog,
# comment the line below if in jupyter
xr.set_options(display_style="text")
rng = np.random.default_rng()
az.style.use("arviz-darkgrid")
idata_centered = az.load_arviz_data("centered_eight")
idata = az.load_arviz_data("rugby")

The kind argument
az.plot_trace generates two columns. The left one calls plot_dist to plot a KDE or histogram of the data, and the right column can contain either the trace itself (which gives the plot its name) or a rank plot, for which two visualizations are available. Rank plots are an alternative to trace plots; see https://arxiv.org/abs/1903.08008 for more details.

fig, axes = plt.subplots(3, 2, figsize=(12, 6))
for i, kind in enumerate(("trace", "rank_bars", "rank_vlines")):
    az.plot_trace(idata, var_names="home", kind=kind, ax=axes[i, :]);
fig.tight_layout()
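
To give a rough idea of what a rank plot displays, here is a minimal pure-Python sketch. It is not ArviZ's implementation: the function name rank_histograms and the toy chains are made up for illustration, and ties between draws are not handled. The draws from all chains are pooled and ranked, and each chain then gets a histogram of the ranks of its own draws.

```python
def rank_histograms(chains, n_bins):
    """Count, per chain, how many of its draws fall in each bin of pooled ranks."""
    # Pool every draw together with the index of the chain it came from
    pooled = [(value, c) for c, chain in enumerate(chains) for value in chain]
    # Sorting the pooled positions by value gives the rank of every draw
    order = sorted(range(len(pooled)), key=lambda i: pooled[i][0])
    total = len(pooled)
    counts = [[0] * n_bins for _ in chains]
    for rank, idx in enumerate(order):
        chain_id = pooled[idx][1]
        bin_id = rank * n_bins // total  # equal-width bins over the ranks
        counts[chain_id][bin_id] += 1
    return counts


# Toy data: two interleaved chains and two chains stuck in different regions
well_mixed = [[0.0, 2.0, 4.0, 6.0], [1.0, 3.0, 5.0, 7.0]]
stuck = [[0.0, 1.0, 2.0, 3.0], [4.0, 5.0, 6.0, 7.0]]

print(rank_histograms(well_mixed, n_bins=2))
print(rank_histograms(stuck, n_bins=2))
```

When all chains explore the same distribution, every per-chain histogram should look roughly uniform; chains that sit in different regions show up as lopsided histograms, which is the pattern to look for in the rank_bars and rank_vlines panels.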
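
# Divergences are marked by default
az.plot_trace(idata_centered, var_names="tau");
# divergences=None hides the divergence markers
az.plot_trace(idata_centered, var_names="tau", divergences=None);
# rug=True adds a rug plot of the draws; rug_kwargs customizes how it is drawn
ax = az.plot_trace(idata, var_names="home", rug=True, rug_kwargs={"alpha": .4})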
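
But what about having both rug and divergences at the same time? Fear not, ArviZ automatically changes the default for divergences from bottom to top to prevent the rug and the divergences from overlapping:

az.plot_trace(idata_centered, var_names="mu", rug=True);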
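
The switch of defaults described above can be sketched in a few lines. This is only an illustration of the behaviour, not ArviZ's source code: the function name resolve_divergences and the "auto" sentinel used here are assumptions.

```python
def resolve_divergences(divergences="auto", rug=False):
    """Return where divergences would be drawn: "top", "bottom" or None."""
    if divergences == "auto":
        # No explicit choice was made: avoid the rug by moving to the top
        return "top" if rug else "bottom"
    # An explicit choice (including None to hide them) is left untouched
    return divergences


print(resolve_divergences())                      # no rug
print(resolve_divergences(rug=True))              # rug present
print(resolve_divergences("bottom", rug=True))    # explicit choice wins
```

Only the default moves; under this sketch, asking explicitly for a position while also drawing the rug is respected, even if the two then overlap.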

lines : list of tuple of (str, dict, array_like), optional
    List of (var_name, {'coord': selection}, [line, positions]) to be overplotted as vertical lines on the density and horizontal lines on the trace.

It is possible that the first thought after reading this line is similar to "What is with this weird format?" Well, this format is actually the standard way ArviZ uses to iterate over xarray.Dataset objects, because it contains all the info about the variable and the selected coordinates as well as the values themselves. The main helper function that handles this is arviz.plots.plot_utils.xarray_var_iter.

This section will be a little different from the other ones, and will focus on boosting plot_trace capabilities with internal ArviZ functions. You may want to skip the section altogether or go straight to the end.

Let's see what xarray_var_iter does with a simple dataset. We will create a dataset with two variables: a will be a 2x3 matrix and b will be a scalar. In addition, the dimensions of a will be labeled.

ds = xr.Dataset({
    "a": (("pos", "direction"), rng.normal(size=(2, 3))),
    "b": 12,
    "pos": ["top", "bottom"],
    "direction": ["x", "y", "z"],
})
ds

from arviz.plots.plot_utils import xarray_var_iter
for var_name, sel, values in xarray_var_iter(ds):
    print(var_name, sel, values)

xarray_var_iter has iterated over every single scalar value without losing track of where each value came from. We can also modify the behaviour to skip some dimensions (e.g. in ArviZ we generally iterate over data dimensions and skip the chain and draw dims).

for var_name, sel, values in xarray_var_iter(ds, skip_dims={"direction"}):
    print(var_name, sel, values)
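
For readers who want to see the idea without xarray, here is a small pure-Python sketch of this kind of iteration over labeled nested lists. It is not how ArviZ implements xarray_var_iter; the names var_iter and select, and the plain-dict layout of data and coords, are assumptions made for this example.

```python
from itertools import product


def select(values, dims, coords, selection):
    """Index nested lists by label, keeping skipped dimensions whole."""
    if not dims:
        return values
    dim, rest = dims[0], dims[1:]
    if dim in selection:
        idx = coords[dim].index(selection[dim])
        return select(values[idx], rest, coords, selection)
    return [select(v, rest, coords, selection) for v in values]


def var_iter(data, coords, skip_dims=frozenset()):
    """Yield (var_name, selection, values) for each coordinate combination."""
    for var_name, (dims, values) in data.items():
        # Only loop over the dimensions that were not asked to be skipped
        loop_dims = [d for d in dims if d not in skip_dims]
        for labels in product(*(coords[d] for d in loop_dims)):
            selection = dict(zip(loop_dims, labels))
            yield var_name, selection, select(values, dims, coords, selection)


# Same layout as the dataset above, with fixed numbers instead of random ones
data = {"a": (("pos", "direction"), [[1, 2, 3], [4, 5, 6]]), "b": ((), 12)}
coords = {"pos": ["top", "bottom"], "direction": ["x", "y", "z"]}

for item in var_iter(data, coords, skip_dims={"direction"}):
    print(item)
```

Skipping "direction" means each yielded value for a is a whole row rather than a single number, while the scalar b still comes out once with an empty selection.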

Now that we know about xarray_var_iter and what it does, we can use it to generate a list in the required format directly from xarray objects. Let's say, for example, that we are interested in plotting the mean as a line in the trace plot:

var_names = ["home", "atts"]
lines = list(xarray_var_iter(idata.posterior[var_names].mean(dim=("chain", "draw"))))
az.plot_trace(idata, var_names=var_names, lines=lines);
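
The same format can also be written by hand, which may be handy for a couple of reference values. The sketch below builds such a list and checks its shape with a small helper. The helper check_lines is hypothetical (it is not part of ArviZ), the reference values are invented, and the coordinate name "team" and label "Wales" are assumptions about the dataset.

```python
from numbers import Real


def check_lines(lines):
    """Return True if every entry looks like (str, dict, iterable of numbers)."""
    for entry in lines:
        if len(entry) != 3:
            raise ValueError(f"expected 3 items, got {len(entry)}: {entry!r}")
        var_name, selection, positions = entry
        if not isinstance(var_name, str) or not isinstance(selection, dict):
            raise ValueError(f"expected (str, dict, values), got {entry!r}")
        if not all(isinstance(p, Real) for p in positions):
            raise ValueError(f"line positions must be numbers: {positions!r}")
    return True


# Invented reference values; "team" and "Wales" are assumed coordinate names
lines = [
    ("home", {}, [0.0]),                       # scalar variable: empty selection
    ("atts", {"team": "Wales"}, [0.0, 0.5]),   # two lines for a single team
]

print(check_lines(lines))
# If the names match the data, the list could then be passed on:
# az.plot_trace(idata, var_names=["home", "atts"], lines=lines)
```

Writing the tuples by hand makes the role of each element explicit: the variable name picks the row of the plot, the dictionary picks the coordinate, and the last element holds the positions of the lines.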