# Source code for bnelearn.util.tensor_util

"""This module implements util functions for PyTorch tensor operations."""

import traceback
from tqdm import tqdm
from typing import List
from math import ceil
import torch

# Message prefix that `apply_with_dynamic_mini_batching` uses to recognise a
# CUDA out-of-memory `RuntimeError` (compared via `str(e).startswith(...)`).
_CUDA_OOM_ERR_MSG_START = "CUDA out of memory. Tried to allocate"
# Message of the `RuntimeError` raised by `apply_with_dynamic_mini_batching`
# when an out-of-memory error occurs although the mini batch size is already
# <= 1 and therefore cannot be reduced any further.
ERR_MSG_OOM_SINGLE_BATCH = "Failed for good. Even a batch_size of 1 leads to OOM!"


def batched_index_select(input: torch.Tensor, dim: int, index: torch.Tensor) -> torch.Tensor:
    """Extends the torch ``index_select`` function to be used for multiple
    batches at once.

    For every batch entry ``b``, the slices of ``input[b]`` along ``dim`` are
    selected and ordered according to ``index[b]``.

    This code is borrowed from
    https://discuss.pytorch.org/t/batched-index-select/9115/11.

    author:
        dashesy

    args:
        input :torch.Tensor: Tensor which is to be indexed. Its first
            dimension is the batch dimension.
        dim :int: Dimension along which to select. Negative values count
            from the last dimension. The batch dimension (0) itself is not
            supported by this implementation.
        index: :torch.Tensor: Index tensor which provides the selecting and
            ordering. Expected to be two-dimensional
            ``(batch_size, n_selected)`` so that it can be expanded to the
            rank of ``input``.

    returns:
        Indexed tensor :torch.Tensor: Same shape as ``input`` except that
            dimension ``dim`` has size ``index.shape[1]``.
    """
    # Map negative dims (e.g. -1) onto their positive counterpart; otherwise
    # the `ii != dim` test below never matches and `index` ends up with one
    # dimension too many for `expand`.
    if dim < 0:
        dim += input.dim()

    # Insert singleton dimensions into `index` at every non-batch position
    # other than `dim`, so that it has the same rank as `input`.
    for ii in range(1, input.dim()):
        if ii != dim:
            index = index.unsqueeze(ii)

    # Broadcast the index over all remaining dimensions; -1 keeps the index's
    # own size along the batch dimension and along `dim`.
    expanse = list(input.shape)
    expanse[0] = -1
    expanse[dim] = -1
    index = index.expand(expanse)

    return torch.gather(input, dim, index)
def apply_with_dynamic_mini_batching(
        function: callable,
        args: torch.Tensor,
        mute: bool = False,
) -> List[torch.Tensor]:
    """Apply the function `function` batch wise to the tensor argument `args`
    with error handling for CUDA Out-Of-Memory problems. Starting with the
    full batch, this method will cut the batch size in half until the
    operation succeeds (or a non-CUDA-OOM error occurs).

    NOTE: The automatic error handling applies to CUDA memory limits only.
    This function does not provide any benefits when processing on CPU with
    regular RAM.

    Args:
        function :callable: function to be evaluated. It is called with a
            chunk of `args` (split along the first dimension) and must return
            a sequence of tensors whose first dimension is the batch
            dimension.
        args :torch.Tensor: pytorch.tensor arguments passed to function.
        mute :bool: Suppress console output.

    Returns:
        function evaluated at args: a list with one tensor per output of
        `function`, each of batch size `args.shape[0]` and allocated on
        `args.device`.

    Raises:
        RuntimeError: if an out-of-memory error occurs even though the mini
            batch size is already down to 1, or if `function` raises any
            `RuntimeError` that is not a CUDA out-of-memory error (re-raised
            unchanged).
    """
    batch_size = args.shape[0]

    # Evaluate one sample up front to learn the number of outputs and their
    # dtypes / per-sample shapes, so the result buffers can be preallocated.
    output_sample = function(args[[0], ...])
    n_outputs = len(output_sample)
    output = [
        torch.empty(
            (batch_size, *tuple(o.shape[1:])),
            dtype=o.dtype,
            device=args.device,
        )
        for o in output_sample
    ]

    # Newer PyTorch versions raise a dedicated RuntimeError subclass for CUDA
    # OOM; older versions only offer the message text to go by.
    oom_error_type = getattr(torch.cuda, "OutOfMemoryError", None)

    mini_batch_size = batch_size
    while True:
        try:
            if not mute:
                print(f"Trying {function} calculation with batch_size {mini_batch_size}...")

            # Split up arguments into smaller chunks of batch size `mini_batch_size`
            mini_args = args.split(mini_batch_size)

            # Iterate over chunks (with a progress bar unless muted)
            chunks = enumerate(mini_args)
            if not mute:
                chunks = tqdm(chunks, total=len(mini_args))

            for i, mini_arg in chunks:
                # Rows of the result buffers that belong to this mini batch;
                # slicing past the end is fine for a shorter last chunk.
                indices = slice(i * mini_batch_size, (i + 1) * mini_batch_size)
                mini_output = function(mini_arg)
                for out_dim in range(n_outputs):
                    output[out_dim][indices] = mini_output[out_dim]

        except RuntimeError as e:
            is_oom = (
                (oom_error_type is not None and isinstance(e, oom_error_type))
                or str(e).startswith(_CUDA_OOM_ERR_MSG_START)
            )
            if not is_oom:
                # Not a memory problem: propagate with the original traceback.
                raise
            if mini_batch_size <= 1:
                traceback.print_exc()
                # Chain the cause so the underlying OOM error stays attached.
                raise RuntimeError(ERR_MSG_OOM_SINGLE_BATCH) from e
            if not mute:
                print("\t... failed (OOM). Decreasing mini batch size.")
            mini_batch_size //= 2

        else:
            if not mute:
                print("\t ... success!")
            return output