bnelearn.util.tensor_util module

This module implements util functions for PyTorch tensor operations.

bnelearn.util.tensor_util.apply_with_dynamic_mini_batching(function: callable, args: Tensor, mute: bool = False) → List[Tensor]

Apply the callable function batch-wise to the tensor argument args, with error handling for CUDA out-of-memory (OOM) errors. Starting with the full batch, this method halves the batch size until the operation succeeds (or an error other than CUDA OOM occurs).

NOTE: The automatic error handling applies to CUDA memory limits only. This function does not provide any benefits when processing on CPU with regular RAM.

Args:

function :callable: Function to be evaluated.
args :torch.Tensor: Tensor arguments passed to function.
mute :bool: Suppress console output.

Returns:

function evaluated at args.
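The halving strategy described above can be sketched as follows. This is a minimal illustration, not bnelearn's actual implementation: the name apply_with_halving_batches is hypothetical, and the sketch assumes that the batch dimension is dimension 0 of args and that the result is a list with one output tensor per mini-batch. CUDA OOM is detected by inspecting the RuntimeError message; in recent PyTorch versions torch.cuda.OutOfMemoryError is a subclass of RuntimeError, so this check covers both old and new versions.

```python
from typing import Callable, List

import torch


def apply_with_halving_batches(
    function: Callable[[torch.Tensor], torch.Tensor],
    args: torch.Tensor,
    mute: bool = False,
) -> List[torch.Tensor]:
    """Hypothetical sketch: evaluate function on args, halving the batch on CUDA OOM."""
    # Start with the full batch (assumed to be dimension 0).
    batch_size = max(1, args.shape[0])
    while True:
        try:
            # torch.split yields chunks of at most batch_size rows along dim 0.
            return [function(chunk) for chunk in torch.split(args, batch_size)]
        except RuntimeError as err:
            # Only out-of-memory errors trigger a retry; any other error,
            # or an OOM at batch size 1, is re-raised unchanged.
            if "out of memory" not in str(err) or batch_size == 1:
                raise
            if torch.cuda.is_available():
                # Release cached blocks before retrying with smaller chunks.
                torch.cuda.empty_cache()
            batch_size = max(1, batch_size // 2)
            if not mute:
                print(f"CUDA OOM, retrying with batch size {batch_size}")
```

Halving rather than shrinking the batch by a constant step keeps the number of retries logarithmic in the initial batch size. Outputs of chunks that succeeded before an OOM are discarded and recomputed at the smaller batch size.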

bnelearn.util.tensor_util.batched_index_select(input: Tensor, dim: int, index: Tensor) → Tensor

Extends torch.index_select so that it can be applied to multiple batches at once.

This code is borrowed from https://discuss.pytorch.org/t/batched-index-select/9115/11.

author:

dashesy

Args:

input :torch.Tensor: Tensor which is to be indexed.
dim :int: Dimension along which to index.
index :torch.Tensor: Index tensor which provides the selecting and ordering.

Returns:

Indexed tensor :torch.Tensor:
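A minimal sketch of how batched selection can be expressed with torch.gather is shown below. It follows the reshape-and-expand idea from the linked forum thread but is not guaranteed to match bnelearn's code. Assumptions: the batch is dimension 0 of input, dim >= 1, and index is an int64 tensor of shape (batch, k) holding a separate index row per batch element. Both function names are hypothetical; batched_index_select_loop is a slow per-batch reference used only to show the intended semantics.

```python
import torch


def batched_index_select_sketch(input: torch.Tensor, dim: int, index: torch.Tensor) -> torch.Tensor:
    """Hypothetical sketch: select along `dim` with a separate index row per batch element."""
    # Reshape index from (batch, k) to (batch, 1, ..., k, ..., 1), with k at position `dim`.
    views = [input.shape[0]] + [-1 if i == dim else 1 for i in range(1, input.dim())]
    # Expand the singleton dimensions to match input; batch and `dim` stay as they are.
    expanse = list(input.shape)
    expanse[0] = -1
    expanse[dim] = -1
    index = index.view(views).expand(expanse)
    # gather: out[b, ..., j, ...] = input[b, ..., index[b, j], ...]
    return torch.gather(input, dim, index)


def batched_index_select_loop(input: torch.Tensor, dim: int, index: torch.Tensor) -> torch.Tensor:
    """Reference semantics: one torch.index_select call per batch element."""
    # input[b] drops the batch dimension, so the selection dimension shifts by one.
    return torch.stack(
        [torch.index_select(input[b], dim - 1, index[b]) for b in range(input.shape[0])]
    )


x = torch.arange(24).view(2, 3, 4)
idx = torch.tensor([[2, 0], [1, 1]])
out = batched_index_select_sketch(x, 1, idx)  # rows 2, 0 of x[0]; rows 1, 1 of x[1]
```

A plain torch.index_select call applies one shared 1-D index to every batch element; the gather-based version lets each batch element use its own index row in a single vectorized call.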