bnelearn.util.distribution_util module

Some utilities to work with torch.Distributions.

bnelearn.util.distribution_util.copy_dist_to_device(dist, device)

A quick and dirty workaround to move torch.Distributions from one device to another.

To do so, we return a copy of the original distribution with all its tensor-valued members moved to the desired device.
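The copy-and-move idea can be illustrated with a minimal sketch. It is not the bnelearn implementation. The helper name `copy_dist_to_device_sketch` is hypothetical, and the sketch assumes that a distribution's parameters are plain tensor attributes in its instance `__dict__`, which holds for simple distributions such as `Normal`:

```python
import copy

import torch
from torch.distributions import Distribution, Normal


def copy_dist_to_device_sketch(dist: Distribution, device) -> Distribution:
    # Hypothetical helper name; this is not the bnelearn implementation.
    # Shallow copy: non-tensor members (shapes, flags) are shared with the original.
    moved = copy.copy(dist)
    # vars(dist) is the instance __dict__, which holds parameters such as loc/scale.
    for name, value in vars(dist).items():
        if isinstance(value, torch.Tensor):
            # Rebind the tensor on the copy only; the original keeps its tensor.
            setattr(moved, name, value.to(device))
    return moved


dist = Normal(torch.zeros(3), torch.ones(3))
target = "cuda" if torch.cuda.is_available() else "cpu"
moved = copy_dist_to_device_sketch(dist, target)
print(moved.loc.device, moved.sample().device)
```

Only top-level tensor attributes are rebound. A composed object such as `torch.distributions.Independent` keeps its tensors inside a nested `base_dist`, which this loop never touches.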

Note that this will only work for the most basic distributions and will likely fail for complex or composed distribution objects. See https://github.com/pytorch/pytorch/issues/7795 for details.