bnelearn.util.tensor_util module
This module implements utility functions for PyTorch tensor operations.
- class bnelearn.util.tensor_util.GaussLayer(**kwargs)
Bases: torch.nn.Module
Custom layer for normally distributed predictions, restricted to be non-negative.
Has no trainable parameters.
- forward(x, deterministic=False, pretrain=False)
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note: Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
- training: bool
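The page only summarizes what GaussLayer does. As a minimal sketch, assuming the input columns are split into means (first half) and log standard deviations (second half) and that non-negativity is enforced by clamping at zero, a parameter-free layer of this kind could look as follows. None of these details are stated on this page; the class name GaussLayerSketch is hypothetical, and because the role of pretrain is not documented here, the sketch accepts but ignores it.

```python
import torch
from torch import nn


class GaussLayerSketch(nn.Module):
    """Hypothetical parameter-free layer producing non-negative Gaussian samples.

    Assumption: the first half of the input columns are means and the second
    half are log standard deviations. This is NOT necessarily bnelearn's layout.
    """

    def forward(self, x: torch.Tensor, deterministic: bool = False,
                pretrain: bool = False) -> torch.Tensor:
        # `pretrain` only mirrors the documented signature; it is ignored here.
        out_dim = x.shape[1] // 2
        mean = x[:, :out_dim]
        if deterministic:
            # Deterministic prediction: the mean, clamped at zero.
            return torch.relu(mean)
        # exp() keeps the scale strictly positive.
        std = torch.exp(x[:, out_dim:2 * out_dim])
        # rsample() uses the reparameterization trick, so gradients reach mean and std.
        sample = torch.distributions.Normal(mean, std).rsample()
        # Clamp at zero to enforce non-negative predictions.
        return torch.relu(sample)
```

Per the Note above, call the instance rather than forward directly, e.g. GaussLayerSketch()(torch.randn(8, 4)) yields a tensor of shape (8, 2). Since the layer owns no nn.Parameter objects, it adds nothing to an optimizer's parameter list.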
- class bnelearn.util.tensor_util.UniformLayer(**kwargs)
Bases: torch.nn.Module
Custom layer for predictions following a uniform distribution.
Has no trainable parameters.
- forward(x, deterministic=False, pretrain=False)
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note: Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
- training: bool
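A comparable minimal sketch for a parameter-free uniform layer, assuming the input columns encode interval centers (first half) and log half-widths (second half). This parameterization is an assumption, since the page does not document it; the class name UniformLayerSketch is hypothetical, and pretrain is again accepted but ignored.

```python
import torch
from torch import nn


class UniformLayerSketch(nn.Module):
    """Hypothetical parameter-free layer producing uniformly distributed samples.

    Assumption: the first half of the input columns are interval centers and the
    second half are log half-widths. This is NOT necessarily bnelearn's layout.
    """

    def forward(self, x: torch.Tensor, deterministic: bool = False,
                pretrain: bool = False) -> torch.Tensor:
        # `pretrain` only mirrors the documented signature; it is ignored here.
        out_dim = x.shape[1] // 2
        center = x[:, :out_dim]
        if deterministic:
            # Deterministic prediction: the interval midpoint.
            return center
        # exp() keeps the half-width strictly positive.
        half_width = torch.exp(x[:, out_dim:2 * out_dim])
        # u ~ U[0, 1); shifting and scaling keeps the sample differentiable w.r.t. x.
        u = torch.rand_like(center)
        return center + half_width * (2.0 * u - 1.0)
```

With all-zero inputs, every center is 0 and every half-width is exp(0) = 1, so samples fall in [-1, 1).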
- bnelearn.util.tensor_util.apply_average_dynamic_mini_batching(function: callable, batch_size: int, shape, device) → List[Tensor]
- bnelearn.util.tensor_util.apply_with_dynamic_mini_batching(function: callable, args: Tensor, mute: bool = False) → List[Tensor]
Apply function batch-wise to the tensor argument args, with error handling for CUDA out-of-memory (OOM) errors. Starting with the full batch, this method halves the batch size until the operation succeeds (or an error other than a CUDA OOM error occurs).
NOTE: The automatic error handling applies to CUDA memory limits only. This function provides no benefit when processing on the CPU with regular RAM.
- Args:
function :callable: function to be evaluated.
args :torch.Tensor: tensor arguments passed to function.
mute :bool: suppress console output.
- Returns:
function evaluated at args.
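To make the halving strategy concrete, here is a minimal, framework-agnostic sketch. The function name dynamic_mini_batching_sketch is hypothetical, and recognizing an OOM condition by the substring "out of memory" in a RuntimeError message is an assumption, not necessarily how bnelearn detects it. As a design choice, this version keeps batches that already succeeded and retries only the failing slice with the smaller batch size.

```python
from typing import Any, Callable, List


def dynamic_mini_batching_sketch(function: Callable[[Any], Any], args: Any,
                                 mute: bool = False) -> List[Any]:
    """Hypothetical sketch: evaluate `function` on slices of `args` along dim 0,
    halving the batch size whenever an out-of-memory RuntimeError occurs."""
    n = len(args)
    batch_size = max(n, 1)  # start with the full batch
    results: List[Any] = []
    start = 0
    while start < n:
        try:
            results.append(function(args[start:start + batch_size]))
            start += batch_size  # slice succeeded; move on to the next one
        except RuntimeError as err:
            # Assumption: CUDA OOM surfaces as a RuntimeError mentioning "out of memory".
            if "out of memory" not in str(err) or batch_size == 1:
                raise  # not an OOM error, or the batch cannot be halved further
            batch_size //= 2
            # On a real GPU one might also call torch.cuda.empty_cache() here.
            if not mute:
                print(f"Out of memory; retrying with batch size {batch_size}.")
    return results
```

Applied to a torch.Tensor, args[start:start + batch_size] slices along the first dimension, so each call sees a contiguous mini-batch, and the caller can reassemble the pieces, e.g. with torch.cat(results).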
- bnelearn.util.tensor_util.batched_index_select(input: Tensor, dim: int, index: Tensor) → Tensor
Extends the torch index_select function to be used for multiple batches at once. This code is borrowed from https://discuss.pytorch.org/t/batched-index-select/9115/11.
- author: dashesy
- args:
input :torch.Tensor: tensor which is to be indexed.
dim :int: dimension along which to index.
index :torch.Tensor: index tensor which provides the selection and ordering.
- returns:
Indexed tensor :torch.Tensor:
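For illustration, one common way to build a per-batch index_select is to reshape index so that it spans every other dimension of input and then call torch.gather. The sketch below follows that approach; it is not necessarily the exact code from the linked thread or from bnelearn, and the name batched_index_select_sketch is hypothetical. It assumes input has shape (B, ...), dim >= 1, and index is an integer tensor of shape (B, K), so that out[b] equals input[b].index_select(dim - 1, index[b]).

```python
import torch


def batched_index_select_sketch(input: torch.Tensor, dim: int,
                                index: torch.Tensor) -> torch.Tensor:
    """Hypothetical per-batch index_select built on torch.gather.

    input: (B, ...) tensor; dim: dimension (>= 1) of input to index;
    index: (B, K) integer tensor of positions to pick for each batch element.
    """
    # Reshape index to (B, 1, ..., K, ..., 1) with K at position `dim` ...
    view_shape = [input.shape[0]] + [-1 if i == dim else 1 for i in range(1, input.dim())]
    # ... then expand it to input's shape everywhere except batch and `dim` (-1 keeps a size).
    expand_shape = list(input.shape)
    expand_shape[0] = -1
    expand_shape[dim] = -1
    expanded_index = index.reshape(view_shape).expand(expand_shape)
    # gather picks input[b, ..., index[b, k], ...] for every remaining position.
    return torch.gather(input, dim, expanded_index)
```

The output has the same shape as input except that dimension dim has size K; expand only creates a view, so the enlarged index does not copy memory.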