Source code for globalemu.plotter

"""

This module provides the ``signal_plot`` class. It can be used to assess the
accuracy of emulation of a test data set, given a trained model, and
produces a figure showing the mean, 95th percentile and worst emulations.
Examples of these figures can be found in the
`MNRAS preprint <https://arxiv.org/abs/2104.04336>`__. The figure is saved
as ``eval_plot.pdf`` in the provided ``base_dir``.

"""

import numpy as np
from globalemu.losses import loss_functions
import matplotlib.pyplot as plt


class signal_plot():
    r"""
    Plot the mean, 95th percentile and worst emulations of a test set.

    The class can be initialised with the following kwargs and the
    following code

    .. code:: python

        plotter = signal_plot(parameters, labels, loss_type, predictor,
                              base_dir, **kwargs)

    Initialising the class does all of the work. The predictor is
    evaluated on ``parameters`` and a loss is computed for every predicted
    signal. The figure is then written to ``base_dir + 'eval_plot.pdf'``.

    **Parameters:**

    parameters: **list or np.array**
        | The astrophysical parameters corresponding to the testing data.

    labels: **list or np.array**
        | The signals, corresponding to the input parameters, that we
            want to predict and subsequently plot the mean, 95th
            percentile and worst emulations of. A list is converted to a
            ``np.array`` internally, so it must be rectangular.

    loss_type: **str or function**
        | The metric by which we want to assess the accuracy of
            emulation. The built in loss functions can be accessed by
            setting this variable to 'rmse', 'mse' or 'GEMLoss'.
            Alternatively, a user defined callable can be provided. It
            takes in a label and the corresponding predicted signal and
            returns a loss that can be formatted with ``loss_label``.

    predictor: **globalemu.eval object**
        | An instance of the globalemu eval class that will be used to
            make predictions of the labels from the input parameters.
            Calling it with ``parameters`` must return ``(signals, z)``.

    base_dir: **string**
        | The directory where the signal plot will be saved. It must end
            with '/'.

    **kwargs:**

    rtol: **int or float / default: 1e-2**
        | The relative tolerance used to find a signal with a loss equal
            to the mean loss over all predictions. If no loss lies within
            the tolerance, the signal whose loss is closest to the mean is
            plotted instead.

    atol: **int or float / default: 1e-2**
        | The absolute tolerance used alongside ``rtol`` when finding the
            signal with a loss equal to the mean loss.

    figsizex: **int or float / default: 5**
        | The size of the figure along the x axis, passed to
            plt.subplots().

    figsizey: **int or float / default: 10**
        | The size of the figure along the y axis, passed to
            plt.subplots().

    xHI: **Bool / default: False**
        | If True then ``globalemu`` will act as if it is evaluating a
            neutral fraction history emulator.

    loss_label: **string / default: 'Loss = {:.3f}'**
        | This kwarg can be used to adjust the loss labels in the plot
            legends. For example, if we wanted precision in the 4th
            decimal place we can set ``loss_label= 'Loss = {:.4f}'``.
            Equally, if we wanted to change the name of the loss and add
            in units we can have ``loss_label= 'RMSE = {:.3f} mK'``.

    **Raises:**

    KeyError
        | If an unexpected keyword argument is supplied.

    TypeError
        | If an argument has the wrong type, or ``base_dir`` does not end
            with '/'.

    ValueError
        | If the predictor returns no signals.
    """

    # Keyword arguments accepted by __init__.
    _ALLOWED_KWARGS = frozenset(
        ['xHI', 'rtol', 'atol', 'figsizex', 'figsizey', 'loss_label'])
    # Loss names dispatched to globalemu.losses.loss_functions.
    _BUILTIN_LOSSES = ('rmse', 'mse', 'GEMLoss')

    def __init__(self, parameters, labels, loss_type, predictor, base_dir,
                 **kwargs):
        for key in kwargs:
            if key not in self._ALLOWED_KWARGS:
                raise KeyError(
                    "Unexpected keyword argument in signal_plot(): "
                    + repr(key))

        self.rtol = kwargs.pop('rtol', 1e-2)
        self.atol = kwargs.pop('atol', 1e-2)
        self.figsizex = kwargs.pop('figsizex', 5)
        self.figsizey = kwargs.pop('figsizey', 10)
        self.loss_label = kwargs.pop('loss_label', 'Loss = {:.3f}')

        if not isinstance(self.loss_label, str):
            raise TypeError("'loss_label' must be a string.")

        numeric_kwargs = (('rtol', self.rtol), ('atol', self.atol),
                          ('figsizex', self.figsizex),
                          ('figsizey', self.figsizey))
        for name, value in numeric_kwargs:
            # bool is a subclass of int but is not a meaningful tolerance
            # or figure size, so it is rejected explicitly.
            if isinstance(value, bool) or not isinstance(value, (int, float)):
                raise TypeError(
                    "'" + name + "' must be an integer or a float.")

        self.parameters = parameters
        if not isinstance(self.parameters, (np.ndarray, list)):
            raise TypeError("'parameters' must be a list or np.array.")

        self.labels = labels
        if not isinstance(self.labels, (np.ndarray, list)):
            raise TypeError("'labels' must be a list or np.array.")

        self.loss_type = loss_type
        self.base_dir = base_dir

        is_builtin_loss = (isinstance(self.loss_type, str)
                           and self.loss_type in self._BUILTIN_LOSSES)
        if not is_builtin_loss and not callable(self.loss_type):
            raise TypeError("'loss_type' must be a string from the "
                            + "predefined set (see documentation) or a "
                            + "user defined function.")

        if not isinstance(self.base_dir, str):
            raise TypeError("'base_dir' must be a string.")
        if not self.base_dir.endswith('/'):
            raise TypeError("'base_dir' must end with '/'.")

        self.predictor = predictor
        if not callable(predictor):
            raise TypeError("'predictor' should be an instance of "
                            + "globalemu.eval.")

        self.xHI = kwargs.pop('xHI', False)
        if not isinstance(self.xHI, bool):
            raise TypeError("'xHI' should be either True or False.")

        signal, z = self.predictor(self.parameters)
        signal = np.asarray(signal)
        # Index through an array version of the labels so that list input
        # supports the row indexing used below; self.labels is untouched.
        labels = np.asarray(self.labels)
        if len(signal) == 0:
            raise ValueError("The predictor returned no signals to plot.")

        loss = self._compute_losses(labels, signal)
        indices = self._select_indices(loss, self.rtol, self.atol)
        self._plot(z, labels, signal, loss, indices)

    def _compute_losses(self, labels, signal):
        """Return a np.array holding one loss value per predicted signal.

        ``labels[i]`` is compared with ``signal[i]`` using either the user
        supplied callable or the named loss from ``globalemu.losses``.
        """
        loss = []
        for i, pred in enumerate(signal):
            label = labels[i]
            if not isinstance(self.loss_type, str):
                loss.append(self.loss_type(label, pred))
                continue
            lf = loss_functions(label, pred)
            if self.loss_type == 'rmse':
                # NOTE(review): only rmse is converted with .numpy();
                # presumably it returns a TensorFlow tensor. Confirm
                # whether mse/GEMLoss need the same treatment.
                loss.append(lf.rmse().numpy())
            elif self.loss_type == 'mse':
                loss.append(lf.mse())
            else:
                # loss_type was validated in __init__, so this is GEMLoss.
                loss.append(lf.GEMLoss())
        return np.array(loss)

    @staticmethod
    def _select_indices(loss, rtol, atol):
        """Return the (mean, 95th percentile, worst) signal indices.

        The mean index is the first signal whose loss is within
        ``rtol``/``atol`` of the mean loss. If no loss is within that
        tolerance, the signal whose loss is nearest the mean is used
        instead of failing. The 95th percentile index is the signal at
        rank ``floor(0.95 * n)`` of the ascending loss ordering. The worst
        index is the first signal with the largest loss.
        """
        mean_loss = loss.mean()
        matches = np.where(
            np.isclose(loss, mean_loss, rtol=rtol, atol=atol))[0]
        if matches.size:
            mean_idx = int(matches[0])
        else:
            mean_idx = int(np.argmin(np.abs(loss - mean_loss)))

        order = np.argsort(loss)
        # Integer arithmetic avoids float rounding in the rank.
        p95_idx = int(order[len(loss) * 95 // 100])

        worst_idx = int(np.argmax(loss))
        return mean_idx, p95_idx, worst_idx

    def _plot(self, z, labels, signal, loss, indices):
        """Draw the mean, 95% and worst panels and save them to base_dir."""
        if self.xHI:
            ylabel = r'$x_{HI}$'
        else:
            # Brightness temperatures are in millikelvin (mK).
            ylabel = r'$T_{21}$ [mK]'

        fig, axes = plt.subplots(
            3, 1, figsize=(self.figsizex, self.figsizey), sharex=True)
        try:
            titles = ('Mean:', '95%:', 'Worst:')
            for ax, idx, title in zip(axes, indices, titles):
                ax.plot(z, labels[idx, :], label='True Signal')
                ax.plot(z, signal[idx, :],
                        label=self.loss_label.format(loss[idx]))
                ax.legend(title=title)
                ax.set_ylabel(ylabel)

            # An invisible axis spanning the whole figure carries a single
            # shared x label for the three stacked panels.
            overlay = fig.add_subplot(111, frame_on=False)
            overlay.tick_params(bottom=False, left=False, labelcolor='none')
            overlay.set_xlabel('$z$')
            fig.tight_layout()
            fig.subplots_adjust(hspace=0, wspace=0)
            fig.savefig(self.base_dir + 'eval_plot.pdf')
        finally:
            # Release the figure even if saving fails.
            plt.close(fig)