Source code for torchquad.plots.plot_convergence

import matplotlib.pyplot as plt
import numpy as np


def plot_convergence(evals, fvals, ground_truth, labels, dpi=150):
    """Plots errors vs. function evaluations (fevals) and shows the convergence rate.

    A new figure is created and one semi-log curve (logarithmic y-axis) of the
    absolute error is drawn per method. Each legend entry is the method name
    followed by its "Convergence Rate", computed here as the mean ratio of
    each absolute error to the next one (``abs_err[i] / abs_err[i + 1]``).
    The figure is neither shown nor saved by this function.

    Args:
        evals (list of np.array): Number of evaluations, for each method a np.array of ints.
        fvals (list of np.array): Function values for evals.
        ground_truth (np.array): Ground truth values.
        labels (list): Method names.
        dpi (int, optional): Plot dpi. Defaults to 150.
    """
    plt.figure(dpi=dpi)
    for evals_item, f_item, label in zip(evals, fvals, labels):
        evals_item = np.array(evals_item)
        abs_err = np.abs(np.asarray(f_item) - np.asarray(ground_truth))

        # Ratio of each error to its successor; the small constant keeps the
        # division finite when a later error is exactly zero.
        err_ratios = np.abs(abs_err[:-1] / (abs_err[1:] + 1e-16))

        # With fewer than two data points there are no ratios to average.
        # Report NaN directly instead of letting np.mean warn about an
        # empty slice (the resulting label text is the same: "nan").
        if err_ratios.size:
            abs_err_delta = np.mean(err_ratios)
        else:
            abs_err_delta = float("nan")

        label = label + "\nConvergence Rate: " + "{:.2e}".format(abs_err_delta)
        plt.semilogy(evals_item, abs_err, label=label)

    plt.legend(fontsize=6)
    plt.xlabel("# of function evaluations")
    plt.ylabel("Absolute error")