bnelearn.util.tensor_util module
This module implements utility functions for PyTorch tensor operations.
- bnelearn.util.tensor_util.apply_with_dynamic_mini_batching(function: callable, args: Tensor, mute: bool = False) → List[Tensor]
Apply function batch-wise to the tensor argument args, with error handling for CUDA out-of-memory (OOM) errors. Starting with the full batch, this method halves the batch size until the operation succeeds (or an error other than a CUDA OOM error occurs).
NOTE: The automatic error handling applies to CUDA memory limits only. This function does not provide any benefits when processing on CPU with regular RAM.
- Args:
function :callable: Function to be evaluated.
args :torch.Tensor: Tensor arguments passed to function.
mute :bool: Suppress console output.
- Returns:
function evaluated at args.
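The halving strategy described above can be sketched as follows. This is a minimal, hypothetical sketch, not bnelearn's implementation: the name apply_with_halving_batches is made up, and it assumes a CUDA OOM can be recognized by the substring "CUDA out of memory" in the RuntimeError message (in recent PyTorch versions the dedicated torch.cuda.OutOfMemoryError is a subclass of RuntimeError). It also concatenates the per-batch outputs with torch.cat along the first dimension, whereas the documented function is annotated as returning List[Tensor].

```python
from typing import Callable, List

import torch


def apply_with_halving_batches(
    function: Callable[[torch.Tensor], torch.Tensor],
    args: torch.Tensor,
    mute: bool = False,
) -> torch.Tensor:
    """Hypothetical sketch: evaluate `function` on `args` in mini-batches,
    halving the batch size whenever a CUDA out-of-memory error occurs."""
    n = args.shape[0]
    batch_size = max(n, 1)  # start with the full batch
    while True:
        try:
            # Evaluate chunk by chunk along the first (batch) dimension.
            outputs: List[torch.Tensor] = [
                function(args[i:i + batch_size]) for i in range(0, n, batch_size)
            ]
            # Assumption: per-batch outputs are concatenated into one tensor.
            return torch.cat(outputs, dim=0)
        except RuntimeError as err:
            # Only CUDA OOM errors trigger a retry; anything else propagates,
            # as does an OOM that still occurs with a batch size of 1.
            if "CUDA out of memory" not in str(err) or batch_size == 1:
                raise
            if torch.cuda.is_available():
                torch.cuda.empty_cache()  # release cached blocks before retrying
            batch_size //= 2
            if not mute:
                print(f"CUDA OOM caught, retrying with batch size {batch_size}")
```

Because only the error message is inspected, the retry logic can be exercised on CPU with a function that raises a RuntimeError containing "CUDA out of memory" for batches above some size; on a plain CPU run without such a simulated error, the first attempt simply succeeds, which matches the NOTE above.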
- bnelearn.util.tensor_util.batched_index_select(input: Tensor, dim: int, index: Tensor) → Tensor
Extends the torch index_select function so that it can be applied to multiple batches at once. This code is borrowed from https://discuss.pytorch.org/t/batched-index-select/9115/11.
- author:
dashesy
- args:
input :torch.Tensor: Tensor which is to be indexed.
dim :int: Dimension along which to index.
index :torch.Tensor: Index tensor which provides the selecting and ordering.
- returns:
Indexed tensor :torch.Tensor:
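A gather-based approach in the spirit of the linked forum thread can be sketched as below. It is not necessarily bnelearn's exact code: batched_index_select_sketch is a hypothetical name, and the sketch assumes dimension 0 is the batch dimension (so dim >= 1) and that index is a 2-D int64 tensor of shape (batch, k), holding one row of indices per batch.

```python
import torch


def batched_index_select_sketch(
    input: torch.Tensor, dim: int, index: torch.Tensor
) -> torch.Tensor:
    """Hypothetical sketch: for each batch b, select entries of input[b]
    along `dim` at the positions index[b] (index has shape (B, K))."""
    # Reshape index to (B, 1, ..., K, ..., 1), with K at position `dim`.
    views = [input.shape[0]] + [1 if i != dim else -1 for i in range(1, input.dim())]
    # Expand to input's shape except at the batch and selection dimensions
    # (-1 keeps the existing size of that dimension).
    expanse = list(input.shape)
    expanse[0] = -1
    expanse[dim] = -1
    index = index.reshape(views).expand(expanse)
    # torch.gather picks input[b, ..., index[b, k], ...] along `dim`.
    return torch.gather(input, dim, index)
```

For an input of shape (B, N, D) with dim=1, the result has shape (B, K, D) and equals stacking x[b].index_select(0, index[b]) over all batches b, but without a Python loop over the batch.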