bnelearn.util.tensor_util module

This module implements utility functions for PyTorch tensor operations.

class bnelearn.util.tensor_util.GaussLayer(**kwargs)[source]

Bases: Module

Custom layer for normally distributed predictions (non-negative).

Has no trainable parameters.

forward(x, deterministic=False, pretrain=False)[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance itself afterwards instead, since the former takes care of running the registered hooks while the latter silently ignores them.

training: bool
class bnelearn.util.tensor_util.UniformLayer(**kwargs)[source]

Bases: Module

Custom layer for predictions following a uniform distribution.

Has no trainable parameters.

forward(x, deterministic=False, pretrain=False)[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance itself afterwards instead, since the former takes care of running the registered hooks while the latter silently ignores them.

training: bool
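The documentation above only states that both layers are parameter-free and map network outputs to a non-negative Normal or a uniform prediction. The sketch below illustrates one plausible reading of that contract; the split of the last input dimension into distribution parameters, the softplus transform, and the use of `deterministic` to return a point prediction are all assumptions, not the actual bnelearn implementation.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class GaussLayerSketch(nn.Module):
    """Hedged sketch: non-negative, normally distributed predictions.

    Assumes the last input dimension holds (mean, raw_std) pairs.
    Has no trainable parameters, matching the documented classes.
    """

    def forward(self, x, deterministic=False, pretrain=False):
        mean, raw_std = x.chunk(2, dim=-1)
        if deterministic or pretrain:
            return F.relu(mean)  # point prediction, clipped to be non-negative
        std = F.softplus(raw_std)  # ensure a valid (positive) scale
        sample = mean + std * torch.randn_like(std)  # reparameterized Normal sample
        return F.relu(sample)  # enforce non-negativity


class UniformLayerSketch(nn.Module):
    """Hedged sketch: predictions following a uniform distribution.

    Assumes the last input dimension holds (lower_bound, raw_width) pairs.
    """

    def forward(self, x, deterministic=False, pretrain=False):
        lower, raw_width = x.chunk(2, dim=-1)
        width = F.softplus(raw_width)  # ensure a positive interval width
        if deterministic or pretrain:
            return lower + 0.5 * width  # midpoint of the interval
        return lower + width * torch.rand_like(width)  # sample from U(lower, lower + width)
```

Because neither layer owns parameters, both can be appended to any network head whose final dimension is even, e.g. `nn.Sequential(nn.Linear(10, 2), GaussLayerSketch())`.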
bnelearn.util.tensor_util.apply_average_dynamic_mini_batching(function: callable, batch_size: int, shape, device) List[Tensor][source]
bnelearn.util.tensor_util.apply_with_dynamic_mini_batching(function: callable, args: Tensor, mute: bool = False) List[Tensor][source]

Apply the function `function` batch-wise to the tensor argument `args`, with error handling for CUDA out-of-memory (OOM) errors. Starting with the full batch, this method halves the batch size until the operation succeeds (or a non-CUDA-OOM error occurs).

NOTE: The automatic error handling applies to CUDA memory limits only. This function does not provide any benefits when processing on CPU with regular RAM.

Args:

function :callable: Function to be evaluated.
args :torch.Tensor: Tensor arguments passed to function.
mute :bool: Suppress console output.

Returns:

function evaluated at args.
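The halving strategy described above can be sketched as follows. This is a simplified re-implementation, not the bnelearn source: the function name `apply_with_oom_halving`, the assumption that `args` is batched along dim 0, and returning a single concatenated tensor (rather than the documented `List[Tensor]`) are all simplifications for illustration.

```python
import torch


def apply_with_oom_halving(function, args, mute=False):
    """Hedged sketch: evaluate `function` on `args`, halving the chunk
    size whenever a CUDA out-of-memory error occurs."""
    chunk_size = args.shape[0]  # start with the full batch
    while True:
        try:
            outputs = [function(chunk) for chunk in args.split(chunk_size, dim=0)]
            return torch.cat(outputs, dim=0)
        except RuntimeError as error:
            # Only CUDA OOM errors trigger a retry; anything else is re-raised.
            if 'out of memory' not in str(error) or chunk_size <= 1:
                raise
            torch.cuda.empty_cache()  # release cached blocks before retrying
            chunk_size //= 2
            if not mute:
                print(f'CUDA OOM -- retrying with chunk size {chunk_size}')
```

As the note above states, the fallback only helps against CUDA memory limits; on CPU the `except` branch never fires and the function behaves like a plain batched call.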

bnelearn.util.tensor_util.batched_index_select(input: Tensor, dim: int, index: Tensor) Tensor[source]

Extends the torch.index_select function to operate on multiple batches at once.

This code is borrowed from https://discuss.pytorch.org/t/batched-index-select/9115/11.

author:

dashesy

args:

input :torch.Tensor: Tensor which is to be indexed.
dim :int: Dimension along which to index.
index :torch.Tensor: Index tensor which provides the selection and ordering.

returns:

Indexed tensor :torch.Tensor:
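The linked forum snippet by dashesy expands the index tensor to the input's shape along every non-selected dimension and then delegates to `torch.gather`. A self-contained version of that approach, reproduced here as a sketch:

```python
import torch


def batched_index_select(input, dim, index):
    """Select along `dim` with a separate index per batch element.

    `index` has one fewer (trailing) dimension than `input`; it is
    unsqueezed and expanded so torch.gather can apply it per batch.
    """
    for ii in range(1, len(input.shape)):
        if ii != dim:
            index = index.unsqueeze(ii)  # add singleton dims to align with input
    expanse = list(input.shape)
    expanse[dim] = -1  # keep the index's own size along the selected dim
    index = index.expand(expanse)
    return torch.gather(input, dim, index)
```

For example, with `input` of shape `(2, 3, 4)`, `dim=1`, and `index` of shape `(2, 2)`, each batch element selects its own two rows, yielding a `(2, 2, 4)` result, which plain `torch.index_select` (one shared index for all batches) cannot express.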