Source code for chi.plots._sampling

#
# This file is part of the chi repository
# (https://github.com/DavAug/chi/) which is released under the
# BSD 3-clause license. See accompanying LICENSE.md for copyright notice and
# full license details.
#

import copy

import pints
import plotly.colors
import plotly.graph_objects as go
import xarray as xr

from chi import plots


[docs] class MarginalPosteriorPlot(plots.MultiSubplotFigure): """ A figure class that visualises the marginal posterior probability for each parameter across individuals. One figure is generated for each parameter, which contains a marginal histogram of the sampled parameter values for each individual. The estimates for each indiviudal are plotted next to each other. This figure can be used to assess the convergence of the sampling method, as well as the variation of parameter estimates across individuals. Extends :class:`MultiFigure`. """ def __init__(self): super(MarginalPosteriorPlot, self).__init__() def _add_histogram_plots(self, fig_id, data, colors): """ Adds histogram plots of the parameter samples for each individual to the figure. One figure contains only the histograms of one parameter. """ # Get number of colours n_colors = len(colors) # Check whether data are samples of a population parameter or # contains samples for many individuals is_population = False try: # Get ids of individuals ids = data.individual.values except AttributeError: is_population = True if is_population is True: # Compute diagnostics samples = data.values diagnostics = self._compute_diagnostics(samples) # Add trace color = colors[0] index = 0 individual = 'Population' samples = samples.flatten() self._add_trace( fig_id, index, individual, samples, diagnostics, color) return None # Add trace for each individual for index, individual in enumerate(ids): # Get data for individual samples = data.sel(individual=individual).dropna(dim='draw') # Compute diagnostics samples = samples.values diagnostics = self._compute_diagnostics(samples) # Add trace samples = samples.flatten() color = colors[index % n_colors] self._add_trace( fig_id, index, individual, samples, diagnostics, color) def _add_trace( self, fig_id, index, individual, samples, diagnostics, color): """ Adds a histogram of an indiviudals samples to a figure. """ # Get figure fig = self._figs[fig_id] # Add trace rhat, = diagnostics fig.add_trace( go.Histogram( y=samples, name='%s' % str(individual), hovertemplate=( 'Sample: %{y:.2f}<br>' + 'Rhat: %.02f<br>' % rhat), visible=True, marker=dict(color=color), opacity=0.8), row=1, col=index+1) # Turn off xaxis ticks fig.update_xaxes( tickvals=[], row=1, col=index+1) # Set x axis title to individual fig.update_xaxes( title_text=str(individual), row=1, col=index+1) def _compute_diagnostics(self, data): """ Computes and returns convergence metrics. - Rhat """ # Compute rhat rhat = pints.rhat(chains=data) # Collect diagnostics diagnostics = [rhat] return diagnostics
[docs] def add_data(self, data): """ Adds marginal histograms of the samples across runs to the figure. The histograms of population parameters are visualised in separate figures, while the individual parameters for one parameter type are grouped together. :param data: A :class:`xarray.Dataset` with the posterior samples. :type data: xarray.Dataset """ # Check input format if not isinstance(data, xr.Dataset): raise TypeError( 'The data has to be a xarray.Dataset.') dims = sorted(list(data.dims)) expected_dims = ['chain', 'draw', 'individual'] if (len(dims) == 2): expected_dims = ['chain', 'draw'] for dim in expected_dims: if dim not in dims: raise ValueError( 'The data must have the dimensions ' '(chain, draw, individual). The current dimensions are <' + str(dims) + '>.') # Get a colours colors = plotly.colors.qualitative.Plotly # Create a template figure (assigns it to self._fig) try: n_ids = len(data.individual) except AttributeError: # If data does not have individual attribute, all parameters # are population parameters n_ids = 1 self._create_template_figure( rows=1, cols=n_ids, x_title='Normalised counts', spacing=0.01) # Create one figure for each parameter figs = [] parameters = list(data.data_vars.keys()) for parameter in parameters: # Check that parameter has as many ids as columns in template # figure try: number_ids = len(data[parameter].individual) except AttributeError: # If parameter does not have individual attribute, it's a # population parameter number_ids = 1 if number_ids != n_ids: # Overwrite old n_ids n_ids = number_ids # Create a new template self._create_template_figure( rows=1, cols=n_ids, x_title='Normalised counts', spacing=0.01) # Append a copy of the template figure to all figures figs.append(copy.copy(self._fig)) self._figs = figs # Add samples to parameter figures for index, parameter in enumerate(parameters): # Set y label of plot to parameter name self._figs[index].update_yaxes( title_text=parameter, row=1, col=1) # Get estimates for this parameter samples = data[parameter] # Add marginal histograms for all individuals self._add_histogram_plots(index, samples, colors)