Source code for torchquad.utils.enable_cuda

from loguru import logger

from .set_precision import set_precision


def enable_cuda(data_type="float32"):
    """This function sets torch's default device to CUDA if possible. Call before declaring any variables!

    The default precision can be set here initially, or using set_precision later.

    Args:
        data_type ("float32", "float64" or None, optional): Data type to use.
            If None, skip the call to set_precision. Defaults to "float32".
    """
    import torch

    if torch.cuda.is_available():
        logger.info("PyTorch VERSION: " + str(torch.__version__))
        logger.info("CUDNN VERSION: " + str(torch.backends.cudnn.version()))
        logger.info("Number of CUDA Devices: " + str(torch.cuda.device_count()))
        logger.info("Active CUDA Device: GPU" + str(torch.cuda.current_device()))
        if data_type is not None:
            set_precision(data_type)
    else:
        logger.warning(
            "Error enabling CUDA. cuda.is_available() returned False. CPU will be used."
        )
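
A minimal usage sketch, assuming torchquad re-exports enable_cuda and set_precision at the package level as in its tutorials:

# Usage sketch: call enable_cuda() before creating any tensors so that
# subsequent computations pick up the configured default precision/device.
from torchquad import enable_cuda, set_precision

enable_cuda()                # logs CUDA info and sets float32 precision if a GPU is available
# set_precision("float64")   # precision can also be changed later, as the docstring notes

import torch
x = torch.linspace(0, 1, 5)  # created with the defaults configured above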