# Source code for torchquad.utils.set_precision

from loguru import logger
import os
import sys


def _get_precision(backend):
    """Look up the default precision stored for a numerical backend.

    The value is read from the environment variable
    ``TORCHQUAD_DTYPE_<BACKEND>``, where ``<BACKEND>`` is the upper-cased
    backend name.

    Args:
        backend ("numpy" or "tensorflow"): Numerical backend

    Returns:
        str or None: The stored precision string (expected to be "float32"
            or "float64"), or None if the variable is not set
    """
    env_key = "TORCHQUAD_DTYPE_" + backend.upper()
    # os.environ.get already yields None for a missing key.
    return os.environ.get(env_key)


def set_precision(data_type="float32", backend="torch"):
    """Set the default precision for floating-point numbers for the given numerical backend.
    Call before declaring your variables.

    NumPy doesn't have global dtypes:
    https://github.com/numpy/numpy/issues/6860
    Therefore, the requested dtype is stored in the TORCHQUAD_DTYPE_NUMPY
    environment variable, and torchquad sets the dtype argument for NumPy
    when initialising the integration domain.

    An unsupported data type or backend does not raise: an error is logged
    and printed to stderr. For an unsupported data type, float32 is used
    instead; for an unsupported backend, nothing is changed.

    Args:
        data_type (str, optional): Data type to use, either "float32" or "float64".
            The legacy aliases "float" and "double" are also accepted. Matching
            is case-insensitive. Defaults to "float32".
        backend (str, optional): Numerical backend for which the data type is changed.
            One of "torch", "jax", "tensorflow" or "numpy". Defaults to "torch".
    """
    # Backwards-compatibility: allow "float" and "double". Lower-case first so
    # that both the aliases and the canonical names tolerate upper-case letters.
    normalized = data_type.lower()
    normalized = {"float": "float32", "double": "float64"}.get(normalized, normalized)
    if normalized not in ("float32", "float64"):
        # Report the value exactly as the caller passed it.
        error_msg = f'Invalid data type "{data_type}". Only float32 and float64 are supported. Setting the data type to float32.'
        logger.error(error_msg)
        print(f"ERROR: {error_msg}", file=sys.stderr)
        normalized = "float32"
    data_type = normalized

    if backend == "torch":
        import torch

        # Only move the default device to CUDA if CUDA was already initialized
        # (matching the behaviour of the legacy tensor-type API).
        cuda_enabled = torch.cuda.is_initialized()

        # Probe for set_default_device, not set_default_dtype: the latter also
        # exists in old PyTorch releases that lack set_default_device, so
        # testing for it would send those releases down the modern path and
        # fail with AttributeError as soon as CUDA is initialized.
        if hasattr(torch, "set_default_device"):
            dtype_map = {"float32": torch.float32, "float64": torch.float64}
            torch.set_default_dtype(dtype_map[data_type])
            if cuda_enabled:
                torch.set_default_device("cuda")
                logger.info(f"Setting Torch's default dtype to {data_type} and device to CUDA.")
            else:
                logger.info(f"Setting Torch's default dtype to {data_type} (CPU).")
        else:
            # Legacy approach for older PyTorch versions: the tensor type
            # encodes both the precision and the device.
            tensor_dtype, tensor_dtype_name = {
                ("float32", True): (torch.cuda.FloatTensor, "cuda.Float32"),
                ("float64", True): (torch.cuda.DoubleTensor, "cuda.Float64"),
                ("float32", False): (torch.FloatTensor, "Float32"),
                ("float64", False): (torch.DoubleTensor, "Float64"),
            }[(data_type, cuda_enabled)]
            cuda_enabled_info = "CUDA is initialized" if cuda_enabled else "CUDA not initialized"
            logger.info(
                f"Setting Torch's default tensor type to {tensor_dtype_name} ({cuda_enabled_info})."
            )
            torch.set_default_tensor_type(tensor_dtype)
    elif backend == "jax":
        from jax import config

        # JAX only distinguishes "64-bit enabled" from "64-bit disabled".
        config.update("jax_enable_x64", data_type == "float64")
        logger.info(f"JAX data type set to {data_type}")
    elif backend == "tensorflow":
        import tensorflow as tf

        # Set TensorFlow global precision
        # NOTE(review): unlike the NumPy branch, this does not record
        # TORCHQUAD_DTYPE_TENSORFLOW, so _get_precision("tensorflow") is not
        # updated here - confirm no caller relies on it.
        tf.keras.backend.set_floatx(data_type)
        logger.info(f"TensorFlow default floatx set to {tf.keras.backend.floatx()}")
    elif backend == "numpy":
        # NumPy still lacks global dtype support, so remember the choice in
        # the environment where _get_precision can read it back.
        os.environ["TORCHQUAD_DTYPE_NUMPY"] = data_type
        logger.info(f"NumPy default dtype set to {_get_precision('numpy')}")
    else:
        error_msg = f"Changing the data type is not supported for backend {backend}"
        logger.error(error_msg)
        print(f"ERROR: {error_msg}", file=sys.stderr)