# Source code for coexist.plots

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# File   : plots.py
# License: GNU v3.0
# Author : Andrei Leonard Nicusan <a.l.nicusan@bham.ac.uk>
# Date   : 30.06.2021


import  colorsys

import  numpy               as      np
from    scipy.interpolate   import  NearestNDInterpolator

import  matplotlib.colors   as      mc
import  plotly.express      as      px
import  plotly.graph_objs   as      go
from    plotly.subplots     import  make_subplots

import  coexist




class LightAdjuster:
    '''Lighten or darken a fixed base colour.

    The base colour is given once, at construction, either as a Matplotlib
    colour name, as any colour specification understood by
    `matplotlib.colors.to_rgb`, or as a Plotly-style ``"rgb(r, g, b)"`` string.
    It is stored internally as a (hue, lightness, saturation) tuple.

    Calling the object with a number multiplies the stored lightness by that
    number (values <1.0 darken, 1.0 leaves the colour unchanged, >1.0 lighten)
    and returns the result as a Plotly ``"rgb(...)"`` label. Calling it with an
    iterable of numbers returns a list with one label per number.
    '''

    def __init__(self, color):
        # Named colours are translated through Matplotlib's name table;
        # anything that is not a known name is used as given
        resolved = mc.cnames.get(color, color)

        if resolved.startswith("rgb"):
            # Plotly-style label with 0-255 channels => rescale to 0-1
            rgb = tuple(
                channel / 255 for channel in px.colors.unlabel_rgb(color)
            )
        else:
            rgb = mc.to_rgb(resolved)

        self.c = colorsys.rgb_to_hls(*rgb)


    def adjust(self, amount):
        '''Return the base colour with its lightness multiplied by `amount`,
        clamped to the range [0, 1], as an ``"rgb(...)"`` label.
        '''
        hue, lightness, saturation = self.c

        new_lightness = max(0, min(1, amount * lightness))
        channels = colorsys.hls_to_rgb(hue, new_lightness, saturation)

        # Plotly labels use 0-255 channel values
        return px.colors.label_rgb([channel * 255 for channel in channels])


    def __call__(self, amounts):
        # Iterables are adjusted element by element; scalars directly
        if hasattr(amounts, "__iter__"):
            return [self.adjust(amount) for amount in amounts]

        return self.adjust(amounts)




def format_fig(fig, size=20, font="Computer Modern", template="plotly_white"):
    '''Apply a consistent theme (the one used for the Nature Computational
    Science journal) to a Plotly figure, modifying it in place.

    The same font family and size are set for the figure text, the figure
    title and the x- / y-axis titles; every annotation already present in
    `fig.layout.annotations` gets the same font size. The layout `template`
    is applied last.
    '''

    # Figure-wide and title fonts (LaTeX font by default)
    layout_fonts = dict(
        font_family = font,
        font_size = size,
        title_font_family = font,
        title_font_size = size,
    )
    fig.update_layout(**layout_fonts)

    # Annotations keep their own font settings, so resize them one by one
    for annotation in fig.layout.annotations:
        annotation["font"]["size"] = size

    # Axis titles share a single font specification
    axis_title_fonts = dict(title_font_family = font, title_font_size = size)
    fig.update_xaxes(**axis_title_fonts)
    fig.update_yaxes(**axis_title_fonts)

    fig.update_layout(template = template)


def access(
    access_data,
    select = lambda results: results[:, -1] < np.inf,
    epochs = ...,
    colors = px.colors.qualitative.Set1,
    overall = False,
    means = True,
    confidence = True,
):
    '''Create a Plotly figure showing the solutions tried, uncertainties and
    error values found in a `coexist.Access` run.

    Parameters
    ----------
    access_data : coexist.AccessData or str
        An `AccessData` object containing all information about an ACCES run;
        you can initialise it with
        ``coexist.AccessData.read("folder_path")``. Alternatively, supply the
        ``folder_path`` directly.

    select : function, default lambda results: results[:, -1] < np.inf
        A filtering function used to plot only selected solutions tried. It
        receives the 2D table `results` (parameter columns first, error
        column(s) after them, the last column being the error value used for
        colouring) and must return a boolean mask with one entry per row.
        E.g. to only plot solutions with an error value smaller than 100:
        ``select = lambda results: results[:, -1] < 100``.

    epochs : int or iterable or Ellipsis, default Ellipsis
        The index or indices of the epochs to plot. An `int` signifies a
        single epoch, an iterable (list-like) signifies multiple epochs, while
        an Ellipsis (`...`) signifies all epochs. Negative indices count from
        the last epoch.

    colors : list[str], default plotly.express.colors.qualitative.Set1
        A list of colors used for each parameter plotted; they are reused
        cyclically if there are more parameters than colours. The confidence
        interval fill is built by rewriting an ``"rgb(...)"`` string into
        ``"rgba(...)"``, so use that form when `confidence` is True.

    overall : bool, default False
        If `True`, also plot the overall standard deviation progression; note
        that sometimes all parameters converge but the overall std-dev remains
        high.

    means : bool, default True
        If `True`, also plot the centre of the region explored by CMA-ES.

    confidence : bool, default True
        If `True`, also plot the standard deviation of each parameter as
        confidence intervals.

    Returns
    -------
    plotly.graph_objs.Figure
        A Plotly figure containing subplots with the solutions tried. Call the
        `.show()` method to display it.

    Raises
    ------
    ValueError
        If no solutions are left to plot after applying `select` and `epochs`.

    Examples
    --------
    If `coexist.Access(filepath, random_seed = 12345)` was run, the directory
    "access_seed12345" would have been created. Plot its results:

    >>> import coexist
    >>> data = coexist.AccessData.read("access_seed12345")
    >>> fig = coexist.plots.access(data)
    >>> fig.show()

    Or more tersely:

    >>> import coexist
    >>> coexist.plots.access("access_seed12345").show()

    Only plot solution combinations that yielded an error value < 100:

    >>> coexist.plots.access(
    >>>     data,
    >>>     select = lambda results: results[:, -1] < 100,
    >>> ).show()

    '''

    # Type-checking inputs
    if not isinstance(access_data, coexist.AccessData):
        access_data = coexist.AccessData.read(access_data)

    # Accept a single number / Ellipsis as well as any iterable; materialise
    # it so that non-indexable iterables (sets, generators) work too
    if not hasattr(epochs, "__iter__"):
        epochs = [epochs]
    epochs = list(epochs)

    # Extract data needed from `access_data`
    parameters = access_data.parameters

    # The data columns: [param1, param2, ..., error]
    results = access_data.results.to_numpy()
    epochs_unscaled = access_data.epochs.to_numpy()
    epochs_scaled = access_data.epochs_scaled.to_numpy()
    ns = access_data.population

    # The number of parameters; every remaining column is an error value
    num_parameters = len(parameters)
    num_errors = results.shape[1] - num_parameters
    names = parameters.index

    # Create a subplots grid: one subplot per parameter, one for the standard
    # deviations and one per error column
    ncols = int(np.ceil(np.sqrt(num_parameters + 1 + num_errors)))
    nrows = int(np.ceil((num_parameters + 1 + num_errors) / ncols))

    fig = make_subplots(
        rows = nrows,
        cols = ncols,
        shared_xaxes = True,
    )

    # The epoch index of every row in `results`
    num_epochs = access_data.num_epochs
    epochs_params = np.repeat(np.arange(num_epochs), ns)

    # Filter results plotted based on the error value and selected epochs
    selection = select(results)
    if not epochs or epochs[0] is not Ellipsis:
        # Handle negative epoch indices (counted from the last epoch)
        wanted = {e if e >= 0 else e + num_epochs for e in epochs}

        # Set the booleans in `selection` to False for epochs not requested
        for e in set(range(num_epochs)) - wanted:
            selection[e * ns:e * ns + ns] = False

    # Per-epoch quantities are shown for epochs whose first sample is
    # selected; x and y use the same mask so they always have equal lengths,
    # even if `select` drops individual samples within an epoch
    epoch_selection = selection[::ns]
    epochs_x = epochs_params[::ns][epoch_selection]

    # Compute relative error between 0 and 1
    error = results[selection, -1]
    if error.size == 0:
        raise ValueError(
            "No solutions left to plot after applying `select` and `epochs`."
        )

    error_min = error.min()
    error_span = error.max() - error_min
    if error_span > 0:
        relative_error = (error - error_min) / error_span
    else:
        # All errors identical: avoid 0 / 0 and use the unmodified colour
        relative_error = np.zeros(error.shape)

    # The subplot holding the standard deviations comes after the parameters
    std_row = num_parameters // ncols + 1
    std_col = num_parameters % ncols + 1

    # Plot solutions tried for each parameter
    for i in range(num_parameters):
        row = i // ncols + 1
        col = i % ncols + 1

        # Cycle through the colours if there are more parameters than colours
        color = colors[i % len(colors)]
        adjuster = LightAdjuster(color)

        fig.add_trace(
            go.Scatter(
                name = names[i],
                x = epochs_params[selection],
                y = results[selection, i],
                mode = "markers",
                marker = dict(
                    size = 8,
                    opacity = 0.4,
                    # Lower errors are drawn in a lighter shade
                    color = adjuster(1 - relative_error),
                ),
                showlegend = False,
            ),
            row = row,
            col = col,
        )

        # Plot the unscaled means and std-dev as confidence intervals
        mu = epochs_unscaled[epoch_selection, i]
        if means:
            fig.add_trace(
                go.Scatter(
                    x = epochs_x,
                    y = mu,
                    mode = "lines",
                    line = dict(
                        width = 2,
                        color = color,
                    ),
                    showlegend = False,
                ),
                row = row,
                col = col,
            )

        if confidence:
            # Add transparency to confidence interval colour
            color_alpha = color.replace("rgb", "rgba").split(")")[0] + ",0.2)"

            std = epochs_unscaled[epoch_selection, i + num_parameters]
            std_lo = mu - std
            std_hi = mu + std

            # Invisible lower bound, then upper bound filled down to it
            fig.add_trace(
                go.Scatter(
                    x = epochs_x,
                    y = std_lo,
                    mode = "lines",
                    line_width = 0,
                    hoverinfo = "skip",
                    showlegend = False,
                ),
                row = row,
                col = col,
            )

            fig.add_trace(
                go.Scatter(
                    x = epochs_x,
                    y = std_hi,
                    mode = "lines",
                    line_width = 0,
                    fill = 'tonexty',
                    fillcolor = color_alpha,
                    hoverinfo = "skip",
                    showlegend = False,
                ),
                row = row,
                col = col,
            )

        # Plot the scaled standard deviations after parameter values
        fig.add_trace(
            go.Scatter(
                name = names[i],
                x = epochs_x,
                y = epochs_scaled[epoch_selection, i + num_parameters],
                mode = "lines",
                line = dict(
                    color = color,
                ),
            ),
            row = std_row,
            col = std_col,
        )

    # Plot the overall standard deviation
    # NOTE(review): this reads the second-to-last column of `results`, while
    # the rest of this function treats all columns after the parameters as
    # error values - confirm which column holds the overall std-dev
    if overall:
        fig.add_trace(
            go.Scatter(
                name = "Overall standard deviation",
                x = epochs_params[selection],
                y = results[selection, -2],
                mode = "lines",
                line = dict(
                    color = "black",
                )
            ),
            row = std_row,
            col = std_col,
        )

    # Plot the error values, one subplot per error column
    for i in range(num_errors):
        row = (num_parameters + 1 + i) // ncols + 1
        col = (num_parameters + 1 + i) % ncols + 1

        fig.add_trace(
            go.Scatter(
                x = epochs_params[selection],
                y = results[selection, num_parameters + i],
                mode = "markers",
                marker = dict(
                    size = 8,
                    opacity = 0.4,
                    color = results[selection, -1],
                    colorscale = "cividis",
                ),
                showlegend = False,
            ),
            row = row,
            col = col,
        )

    # Set graph ranges and axis labels
    for i in range(num_parameters):
        xaxis = "xaxis" if i == 0 else f"xaxis{i + 1}"
        yaxis = "yaxis" if i == 0 else f"yaxis{i + 1}"

        fig.layout[xaxis].update(title = "Epoch")
        fig.layout[yaxis].update(title = names[i])

    # Set axis labels for the standard deviation subplot
    fig.layout[f"xaxis{num_parameters + 1}"].update(title = "Epoch")
    fig.layout[f"yaxis{num_parameters + 1}"].update(
        title = "Standard Deviation"
    )

    # Set axis labels for every error value subplot
    for i in range(num_errors):
        fig.layout[f"xaxis{num_parameters + 2 + i}"].update(title = "Epoch")

        # If the default error names are given, capitalise them
        title = access_data.results.columns[num_parameters + i]
        if title.startswith("error"):
            title = "E" + title[1:]

        # The last error column is labelled as the combined error
        if i == num_errors - 1:
            title = "Combined " + title

        fig.layout[f"yaxis{num_parameters + 2 + i}"].update(title = title)

    format_fig(fig)
    fig.update_layout(title = dict(
        text = "ACCES Convergence Plot",
        font_size = 25,
    ))

    return fig


def access2d(
    access_data,
    resolution = (500, 500),
    width = 0.2,
    select = lambda results: results[:, -1] < np.inf,
    epochs = ...,
    colorscale = "Blues_r",
    scaled = False,
    seeds = True,
):
    '''Create a Plotly figure showing 2D Voronoi diagram of the error values
    found in 2D slices of the parameters explored in a `coexist.Access` run.

    One subplot is drawn per pair of parameters (a lower-triangular grid).
    Each pixel is coloured by the error value of the closest solution tried
    (nearest-neighbour interpolation).

    Parameters
    ----------
    access_data : coexist.AccessData or str
        An `AccessData` object containing all information about an ACCES run;
        you can initialise it with
        ``coexist.AccessData.read("folder_path")``. Alternatively, supply the
        ``folder_path`` directly.

    resolution : 2-tuple, default (500, 500)
        The number of pixels in the heatmap / Voronoi diagram shown in the
        x- and y-dimensions.

    width : float, default 0.2
        The width of the slices as a ratio of the parameter range. For each
        parameter not shown in a subplot, only solutions within this window,
        centred on that parameter's value in the last selected epoch, are
        used.

    select : function, default lambda results: results[:, -1] < np.inf
        A filtering function used to plot only selected solutions tried. It
        receives the 2D table `results` (parameter columns first, the error
        value in the last column) and must return a boolean mask with one
        entry per row. E.g. to only plot solutions with an error value smaller
        than 100: `select = lambda results: results[:, -1] < 100`.

    epochs : int or iterable or Ellipsis, default Ellipsis
        The index or indices of the epochs to plot. An `int` signifies a
        single epoch, an iterable (list-like) signifies multiple epochs, while
        an Ellipsis (`...`) signifies all epochs. Negative indices count from
        the last epoch.

    colorscale : str, default "Blues_r"
        The colorscale used to colour-code the error value. For a list of
        possible colorscales, see `plotly.com/python/builtin-colorscales`.

    scaled : bool, default False
        If True, plot the scaled data (`results_scaled`, `epochs_scaled` and
        `parameters_scaled` of `access_data`) instead of the unscaled values.

    seeds : bool, default True
        If True, also plot the points representing parameter combinations
        tried.

    Returns
    -------
    plotly.graph_objs.Figure
        A Plotly figure containing subplots with the solutions tried. Call the
        `.show()` method to display it.

    Raises
    ------
    ValueError
        If fewer than two parameters are available, or if no solutions are
        left to plot after applying `select` and `epochs`.

    Examples
    --------
    If `coexist.Access(filepath, random_seed = 12345)` was run, the directory
    "access_seed12345" would have been created. Plot its results:

    >>> import coexist
    >>>
    >>> data = coexist.AccessData.read("access_seed12345")
    >>> fig = coexist.plots.access2d(data)
    >>> fig.show()

    Or more tersely:

    >>> import coexist
    >>> coexist.plots.access2d("access_seed12345").show()

    Only plot the results from epochs 5, 6, 7:

    >>> coexist.plots.access2d(data, epochs = [4, 5, 6]).show()

    Only use the solutions for which the second parameter lies between 0.4
    and 0.6:

    >>> coexist.plots.access2d(
    >>>     data,
    >>>     select = lambda res: (res[:, 1] > 0.4) & (res[:, 1] < 0.6),
    >>> ).show()

    '''

    # Type-checking inputs
    if not isinstance(access_data, coexist.AccessData):
        access_data = coexist.AccessData.read(access_data)

    # Accept a single number / Ellipsis as well as any iterable; materialise
    # it so that non-indexable iterables (sets, generators) work too
    if not hasattr(epochs, "__iter__"):
        epochs = [epochs]
    epochs = list(epochs)

    # Extract data needed from `access_data`
    if scaled:
        results = access_data.results_scaled.to_numpy()
        epochs_raw = access_data.epochs_scaled.to_numpy()
        parameters = access_data.parameters_scaled
    else:
        results = access_data.results.to_numpy()
        epochs_raw = access_data.epochs.to_numpy()
        parameters = access_data.parameters

    num_params = len(parameters)
    if num_params < 2:
        raise ValueError(
            "At least two parameters are needed for a 2D plot; "
            f"found {num_params}."
        )

    names = parameters.index

    ns = access_data.population
    num_epochs = access_data.num_epochs

    # Create a subplots grid
    ncols = num_params - 1
    nrows = num_params - 1

    fig = make_subplots(
        rows = nrows,
        cols = ncols,
        shared_xaxes = True,
        shared_yaxes = True,
        horizontal_spacing = 0.1 / 2 / ncols,
        vertical_spacing = 0.2 / 2 / nrows,
    )

    # Filter results plotted based on the error value and selected epochs
    selection = select(results)
    if not epochs or epochs[0] is not Ellipsis:
        # Handle negative epoch indices
        wanted = {e if e >= 0 else e + num_epochs for e in epochs}

        # Set the booleans in `selection` to False for epochs not requested
        for e in set(range(num_epochs)) - wanted:
            selection[e * ns:e * ns + ns] = False

    results = results[selection]
    epochs_raw = epochs_raw[selection[::ns]]

    if len(results) == 0:
        raise ValueError(
            "No solutions left to plot after applying `select` and `epochs`."
        )

    # Create a 2D map of pixels coloured by the closest measured point's error
    error = results[:, -1]
    error_bounds = [error.min(), error.max()]

    error_span = error_bounds[1] - error_bounds[0]
    if error_span > 0:
        error_scaled = (error - error_bounds[0]) / error_span
    else:
        # All errors identical: avoid 0 / 0 and draw equally-sized markers
        error_scaled = np.zeros(error.shape)

    # Parameter bounds, always indexed by position (the index holds names)
    param_min = parameters["min"]
    param_max = parameters["max"]

    # Plot a lower triangular matrix without the diagonal, so for 3 parameters
    # => lower 2x2 triangle
    for i in range(1, num_params):
        for j in range(i):
            row = i
            col = j + 1

            # Create a 2D error map with each pixel mapped to the closest
            # sample's error
            x = np.linspace(param_min.iloc[j], param_max.iloc[j],
                            resolution[0])
            y = np.linspace(param_min.iloc[i], param_max.iloc[i],
                            resolution[1])

            # Select parameter space slice of given `width`
            cond = np.full(len(results), True)
            others = set(range(num_params)) - {i, j}

            for o in others:
                param_values = results[:, o]
                param_range = param_max.iloc[o] - param_min.iloc[o]

                # Centre the slice on the last selected epoch's value
                mean = epochs_raw[-1, o]

                cond = cond & (
                    (param_values > mean - 0.5 * width * param_range) &
                    (param_values < mean + 0.5 * width * param_range)
                )

            xx, yy = np.meshgrid(x, y)
            error_map = NearestNDInterpolator(
                results[cond][:, [j, i]],
                error[cond],
            )(xx, yy)

            # Plot the 2D error map
            fig.add_trace(
                go.Heatmap(
                    x = x,
                    y = y,
                    z = error_map,
                    zmin = error_bounds[0],
                    zmax = error_bounds[1],
                    colorscale = colorscale,
                    colorbar_title = "Error",
                    # A single colour bar is enough for the whole figure
                    showscale = row == col == 1,
                    showlegend = False,
                ),
                row = row,
                col = col,
            )

            # Plot points tried. Use inverted colorscale for good contrast
            # (i.e. color = 1 / errors)
            if seeds:
                fig.add_trace(
                    go.Scatter(
                        x = results[cond, j],
                        y = results[cond, i],
                        mode = "markers",
                        marker = dict(
                            # Sizes and colours must be restricted to the
                            # same slice as the points they belong to
                            size = 1 + 10 * error_scaled[cond],
                            color = 1 / (error[cond] - error_bounds[0] + 1),
                            colorscale = colorscale,
                            colorbar_title = None,
                        ),
                        showlegend = False,
                    ),
                    row = row,
                    col = col,
                )

            # Set axis labels
            isub = (row - 1) * ncols + col
            xaxis = "xaxis" if isub == 1 else f"xaxis{isub}"
            yaxis = "yaxis" if isub == 1 else f"yaxis{isub}"

            # bounds = parameters[["min", "max"]]
            # fig.layout[xaxis].update(range = bounds.iloc[j].to_numpy())
            # fig.layout[yaxis].update(range = bounds.iloc[i].to_numpy())

            # Only label the outer edges of the subplot grid
            if col == 1:
                fig.layout[yaxis].update(title = names[i])

            if row == nrows:
                fig.layout[xaxis].update(title = names[j])

    format_fig(fig)
    fig.update_xaxes(showgrid = False, zeroline = False)
    fig.update_yaxes(showgrid = False, zeroline = False)

    prefix = "Scaled " if scaled else ""
    fig.update_layout(title = dict(
        text = (
            f"{prefix}ACCES Voronoi Plot - Parameter Slices Width = "
            f"{width * 100:3.1f}% Data Range"
        ),
        font_size = 25,
    ))

    return fig


def surrogate2d(
    access_surrogate,
    resolution = (500, 500),
    width = 0.2,
    select = lambda results: results[:, -1] < np.inf,
    epochs = ...,
    colorscale = "Blues_r",
    scaled = False,
    seeds = True,
):
    '''Create a Plotly figure showing 2D Voronoi diagram of the error values
    found in 2D slices of the parameters explored in a `coexist.Access` run.

    One subplot is drawn per pair of parameters (a lower-triangular grid).
    Each pixel is coloured by the error value of the closest solution tried
    (nearest-neighbour interpolation).

    NOTE(review): only the attributes `parameters`, `results`,
    `results_scaled`, `epochs`, `epochs_scaled`, `population` and
    `num_epochs` are read from the input; it is assumed that a
    `coexist.AccessSurrogate` exposes them like `coexist.AccessData` does -
    confirm against the `AccessSurrogate` class.

    Parameters
    ----------
    access_surrogate : coexist.AccessSurrogate or coexist.AccessData or str
        The object holding the ACCES data to plot. Anything else (e.g. a
        ``folder_path``) is read with ``coexist.AccessData.read``.

    resolution : 2-tuple, default (500, 500)
        The number of pixels in the heatmap / Voronoi diagram shown in the
        x- and y-dimensions.

    width : float, default 0.2
        The width of the slices as a ratio of the parameter range. For each
        parameter not shown in a subplot, only solutions within this window,
        centred on that parameter's value in the last selected epoch, are
        used.

    select : function, default lambda results: results[:, -1] < np.inf
        A filtering function used to plot only selected solutions tried. It
        receives the 2D table `results` (parameter columns first, the error
        value in the last column) and must return a boolean mask with one
        entry per row. E.g. to only plot solutions with an error value smaller
        than 100: `select = lambda results: results[:, -1] < 100`.

    epochs : int or iterable or Ellipsis, default Ellipsis
        The index or indices of the epochs to plot. An `int` signifies a
        single epoch, an iterable (list-like) signifies multiple epochs, while
        an Ellipsis (`...`) signifies all epochs. Negative indices count from
        the last epoch.

    colorscale : str, default "Blues_r"
        The colorscale used to colour-code the error value. For a list of
        possible colorscales, see `plotly.com/python/builtin-colorscales`.

    scaled : bool, default False
        If True, plot the scaled results and rescale the parameter bounds by
        the mean ratio between the unscaled and scaled parameter values.

    seeds : bool, default True
        If True, also plot the points representing parameter combinations
        tried.

    Returns
    -------
    plotly.graph_objs.Figure
        A Plotly figure containing subplots with the solutions tried. Call the
        `.show()` method to display it.

    Raises
    ------
    ValueError
        If fewer than two parameters are available, or if no solutions are
        left to plot after applying `select` and `epochs`.

    Examples
    --------
    If `coexist.Access(filepath, random_seed = 12345)` was run, the directory
    "access_seed12345" would have been created. Plot its results:

    >>> import coexist
    >>> coexist.plots.surrogate2d("access_seed12345").show()

    Only plot the results from epochs 5, 6, 7:

    >>> coexist.plots.surrogate2d(
    >>>     "access_seed12345",
    >>>     epochs = [4, 5, 6],
    >>> ).show()

    '''

    # Type-checking inputs. The rest of the function works on `access_data`,
    # which previously was never bound from `access_surrogate`
    if isinstance(access_surrogate,
                  (coexist.AccessSurrogate, coexist.AccessData)):
        access_data = access_surrogate
    else:
        access_data = coexist.AccessData.read(access_surrogate)

    # Accept a single number / Ellipsis as well as any iterable; materialise
    # it so that non-indexable iterables (sets, generators) work too
    if not hasattr(epochs, "__iter__"):
        epochs = [epochs]
    epochs = list(epochs)

    # Extract data needed from `access_data`
    parameters = access_data.parameters
    num_params = len(parameters)
    if num_params < 2:
        raise ValueError(
            "At least two parameters are needed for a 2D plot; "
            f"found {num_params}."
        )

    names = parameters.index

    # The data columns: [param1, param2, ..., error]
    if scaled:
        results = access_data.results_scaled.to_numpy()
        epochs_raw = access_data.epochs_scaled.to_numpy()

        # Scale parameter values; only the parameter columns are used, so
        # any number of trailing error columns is supported
        scaling = (
            access_data.results.to_numpy() /
            access_data.results_scaled.to_numpy()
        )
        scaling = np.mean(scaling[:, :num_params], axis = 0)

        # Work on a copy so the caller's parameters table is left untouched
        parameters = parameters.copy()
        parameters["min"] /= scaling
        parameters["max"] /= scaling
    else:
        results = access_data.results.to_numpy()
        epochs_raw = access_data.epochs.to_numpy()

    ns = access_data.population
    num_epochs = access_data.num_epochs

    # Create a subplots grid
    ncols = num_params - 1
    nrows = num_params - 1

    fig = make_subplots(
        rows = nrows,
        cols = ncols,
        shared_xaxes = True,
        shared_yaxes = True,
        horizontal_spacing = 0.1 / 2 / ncols,
        vertical_spacing = 0.2 / 2 / nrows,
    )

    # Filter results plotted based on the error value and selected epochs
    selection = select(results)
    if not epochs or epochs[0] is not Ellipsis:
        # Handle negative epoch indices
        wanted = {e if e >= 0 else e + num_epochs for e in epochs}

        # Set the booleans in `selection` to False for epochs not requested
        for e in set(range(num_epochs)) - wanted:
            selection[e * ns:e * ns + ns] = False

    results = results[selection]
    epochs_raw = epochs_raw[selection[::ns]]

    if len(results) == 0:
        raise ValueError(
            "No solutions left to plot after applying `select` and `epochs`."
        )

    # Create a 2D map of pixels coloured by the closest measured point's error
    error = results[:, -1]
    error_bounds = [error.min(), error.max()]

    error_span = error_bounds[1] - error_bounds[0]
    if error_span > 0:
        error_scaled = (error - error_bounds[0]) / error_span
    else:
        # All errors identical: avoid 0 / 0 and draw equally-sized markers
        error_scaled = np.zeros(error.shape)

    # Parameter bounds, always indexed by position (the index holds names)
    param_min = parameters["min"]
    param_max = parameters["max"]

    # Plot a lower triangular matrix without the diagonal, so for 3 parameters
    # => lower 2x2 triangle
    for i in range(1, num_params):
        for j in range(i):
            row = i
            col = j + 1

            # Create a 2D error map with each pixel mapped to the closest
            # sample's error
            x = np.linspace(param_min.iloc[j], param_max.iloc[j],
                            resolution[0])
            y = np.linspace(param_min.iloc[i], param_max.iloc[i],
                            resolution[1])

            # Select parameter space slice of given `width`
            cond = np.full(len(results), True)
            others = set(range(num_params)) - {i, j}

            for o in others:
                param_values = results[:, o]
                param_range = param_max.iloc[o] - param_min.iloc[o]

                # Centre the slice on the last selected epoch's value
                mean = epochs_raw[-1, o]

                cond = cond & (
                    (param_values > mean - 0.5 * width * param_range) &
                    (param_values < mean + 0.5 * width * param_range)
                )

            xx, yy = np.meshgrid(x, y)
            error_map = NearestNDInterpolator(
                results[cond][:, [j, i]],
                error[cond],
            )(xx, yy)

            # Plot the 2D error map
            fig.add_trace(
                go.Heatmap(
                    x = x,
                    y = y,
                    z = error_map,
                    zmin = error_bounds[0],
                    zmax = error_bounds[1],
                    colorscale = colorscale,
                    colorbar_title = "Error",
                    # A single colour bar is enough for the whole figure
                    showscale = row == col == 1,
                    showlegend = False,
                ),
                row = row,
                col = col,
            )

            # Plot points tried. Use inverted colorscale for good contrast
            # (i.e. color = 1 / errors)
            if seeds:
                fig.add_trace(
                    go.Scatter(
                        x = results[cond, j],
                        y = results[cond, i],
                        mode = "markers",
                        marker = dict(
                            # Sizes and colours must be restricted to the
                            # same slice as the points they belong to
                            size = 1 + 10 * error_scaled[cond],
                            color = 1 / (error[cond] - error_bounds[0] + 1),
                            colorscale = colorscale,
                            colorbar_title = None,
                        ),
                        showlegend = False,
                    ),
                    row = row,
                    col = col,
                )

            # Set axis labels
            isub = (row - 1) * ncols + col
            xaxis = "xaxis" if isub == 1 else f"xaxis{isub}"
            yaxis = "yaxis" if isub == 1 else f"yaxis{isub}"

            # bounds = parameters[["min", "max"]]
            # fig.layout[xaxis].update(range = bounds.iloc[j].to_numpy())
            # fig.layout[yaxis].update(range = bounds.iloc[i].to_numpy())

            # Only label the outer edges of the subplot grid
            if col == 1:
                fig.layout[yaxis].update(title = names[i])

            if row == nrows:
                fig.layout[xaxis].update(title = names[j])

    format_fig(fig)
    # fig.update_xaxes(showgrid = False, zeroline = False)
    # fig.update_yaxes(showgrid = False, zeroline = False)

    prefix = "Scaled " if scaled else ""
    fig.update_layout(title = dict(
        text = (
            f"{prefix}ACCES Voronoi Plot - Parameter Slices Width = "
            f"{width * 100:3.1f}% Data Range"
        ),
        font_size = 25,
    ))

    return fig