Contracting a large output lazily
In this example we perform a contraction whose output would be too large to fit in memory. However, we can still generate it in chunks, which is sufficient to compute, for example, the entropy $S = -\sum_x p(x) \log p(x)$ of the output distribution:
%config InlineBackend.figure_formats = ['svg']
import cotengra as ctg
import quimb.tensor as qtn
from autoray import do
Use quimb to make an example factor graph / probability distribution:
# Classical Ising partition function on a 6 x 6 x 6 lattice, built as a
# hyper tensor network (factor graph) at inverse temperature beta = 0.3.
htn = qtn.HTN3D_classical_ising_partition_function(6, 6, 6, beta=0.3)
Here we optionally first convert the tensor network’s data to cupy GPU arrays:
def to_backend(x):
    """Return ``x`` converted to a single-precision cupy array."""
    # Imported inside the function so cupy is only needed when this
    # optional conversion step is actually run.
    import cupy as cp

    return cp.asarray(x, dtype="float32")
htn.apply_to_arrays(to_backend)
Select a subset of output variables (more than we can store the full tensor for!):
# Output variables: the 4 x 3 x 3 block of site indices "s{i},{j},{k}",
# enumerated with k varying fastest, then j, then i.
output_inds = tuple(
    f"s{i},{j},{k}"
    for i in range(4)
    for j in range(3)
    for k in range(3)
)
len(output_inds)
36
htn.draw(highlight_inds=output_inds)
# A single optimizer instance, passed as ``optimize=opt`` to every
# contraction tree search in this example.
opt = ctg.ReusableHyperOptimizer(
minimize="combo",
# here we put the actual amount of storage we are limited to
slicing_reconf_opts={"target_size": 2**28},
# the amount of time we want to spend searching,
# given we can compute at approximately 1e11 ops / sec
max_time="rate:1e11",
progbar=True,
)
First if we need to normalize we compute the full partition function:
tree_Z = htn.contraction_tree(output_inds=(), optimize=opt)
log2[SIZE]: 28.00 log10[FLOPs]: 12.80: 24%|█████████▉ | 31/128 [01:04<03:22, 2.09s/it]
Since it could be a very large or small number we actively renormalize the tensors while contracting into a separate mantissa and exponent:
# With strip_exponent=True the result is unpacked as a (mantissa, exponent)
# pair rather than a single number.
# NOTE(review): the values are consistent with Z = mantissa * 10**exponent
# (base-10 exponent) -- confirm against the cotengra documentation.
Z_mantissa, Z_exponent = tree_Z.contract(
htn.arrays, strip_exponent=True, progbar=True
)
Z_mantissa, Z_exponent
100%|█████████████████████████████████████████████████████████████████████████████████| 32/32 [00:07<00:00, 4.56it/s]
(array(4.2891397, dtype=float32), array(78.32228, dtype=float32))
# Normalize the network by the full partition function
# Z = Z_mantissa * 10**Z_exponent. The mantissa returned with the stripped
# exponent is not guaranteed to be 1, so it must be folded into the
# exponent as well; negating Z_exponent alone leaves the marginals summing
# to Z_mantissa instead of 1.
htn.exponent = -(Z_exponent + do("log10", Z_mantissa))
# this then spreads the exponent among all the actual tensors
htn.equalize_norms_()
TensorNetwork(tensors=540, indices=216)
Tensor(shape=(2, 2), inds=[s0,0,0, s1,0,0], tags={}),
backend=cupy, dtype=float32, data=array([[0.9666009, 0.5304818], [0.5304818, 0.9666009]], dtype=float32)Tensor(shape=(2, 2), inds=[s0,0,0, s0,1,0], tags={}),
backend=cupy, dtype=float32, data=array([[0.9666009, 0.5304818], [0.5304818, 0.9666009]], dtype=float32)Tensor(shape=(2, 2), inds=[s0,0,0, s0,0,1], tags={}),
backend=cupy, dtype=float32, data=array([[0.9666009, 0.5304818], [0.5304818, 0.9666009]], dtype=float32)Tensor(shape=(2, 2), inds=[s0,0,1, s1,0,1], tags={}),
backend=cupy, dtype=float32, data=array([[0.9666009, 0.5304818], [0.5304818, 0.9666009]], dtype=float32)Tensor(shape=(2, 2), inds=[s0,0,1, s0,1,1], tags={}),
backend=cupy, dtype=float32, data=array([[0.9666009, 0.5304818], [0.5304818, 0.9666009]], dtype=float32)Tensor(shape=(2, 2), inds=[s0,0,1, s0,0,2], tags={}),
backend=cupy, dtype=float32, data=array([[0.9666009, 0.5304818], [0.5304818, 0.9666009]], dtype=float32)Tensor(shape=(2, 2), inds=[s0,0,2, s1,0,2], tags={}),
backend=cupy, dtype=float32, data=array([[0.9666009, 0.5304818], [0.5304818, 0.9666009]], dtype=float32)Tensor(shape=(2, 2), inds=[s0,0,2, s0,1,2], tags={}),
backend=cupy, dtype=float32, data=array([[0.9666009, 0.5304818], [0.5304818, 0.9666009]], dtype=float32)Tensor(shape=(2, 2), inds=[s0,0,2, s0,0,3], tags={}),
backend=cupy, dtype=float32, data=array([[0.9666009, 0.5304818], [0.5304818, 0.9666009]], dtype=float32)Tensor(shape=(2, 2), inds=[s0,0,3, s1,0,3], tags={}),
backend=cupy, dtype=float32, data=array([[0.9666009, 0.5304818], [0.5304818, 0.9666009]], dtype=float32)Tensor(shape=(2, 2), inds=[s0,0,3, s0,1,3], tags={}),
backend=cupy, dtype=float32, data=array([[0.9666009, 0.5304818], [0.5304818, 0.9666009]], dtype=float32)Tensor(shape=(2, 2), inds=[s0,0,3, s0,0,4], tags={}),
backend=cupy, dtype=float32, data=array([[0.9666009, 0.5304818], [0.5304818, 0.9666009]], dtype=float32)Tensor(shape=(2, 2), inds=[s0,0,4, s1,0,4], tags={}),
backend=cupy, dtype=float32, data=array([[0.9666009, 0.5304818], [0.5304818, 0.9666009]], dtype=float32)Tensor(shape=(2, 2), inds=[s0,0,4, s0,1,4], tags={}),
backend=cupy, dtype=float32, data=array([[0.9666009, 0.5304818], [0.5304818, 0.9666009]], dtype=float32)Tensor(shape=(2, 2), inds=[s0,0,4, s0,0,5], tags={}),
backend=cupy, dtype=float32, data=array([[0.9666009, 0.5304818], [0.5304818, 0.9666009]], dtype=float32)Tensor(shape=(2, 2), inds=[s0,0,5, s1,0,5], tags={}),
backend=cupy, dtype=float32, data=array([[0.9666009, 0.5304818], [0.5304818, 0.9666009]], dtype=float32)Tensor(shape=(2, 2), inds=[s0,0,5, s0,1,5], tags={}),
backend=cupy, dtype=float32, data=array([[0.9666009, 0.5304818], [0.5304818, 0.9666009]], dtype=float32)Tensor(shape=(2, 2), inds=[s0,1,0, s1,1,0], tags={}),
backend=cupy, dtype=float32, data=array([[0.9666009, 0.5304818], [0.5304818, 0.9666009]], dtype=float32)Tensor(shape=(2, 2), inds=[s0,1,0, s0,2,0], tags={}),
backend=cupy, dtype=float32, data=array([[0.9666009, 0.5304818], [0.5304818, 0.9666009]], dtype=float32)Tensor(shape=(2, 2), inds=[s0,1,0, s0,1,1], tags={}),
backend=cupy, dtype=float32, data=array([[0.9666009, 0.5304818], [0.5304818, 0.9666009]], dtype=float32)Tensor(shape=(2, 2), inds=[s0,1,1, s1,1,1], tags={}),
backend=cupy, dtype=float32, data=array([[0.9666009, 0.5304818], [0.5304818, 0.9666009]], dtype=float32)Tensor(shape=(2, 2), inds=[s0,1,1, s0,2,1], tags={}),
backend=cupy, dtype=float32, data=array([[0.9666009, 0.5304818], [0.5304818, 0.9666009]], dtype=float32)Tensor(shape=(2, 2), inds=[s0,1,1, s0,1,2], tags={}),
backend=cupy, dtype=float32, data=array([[0.9666009, 0.5304818], [0.5304818, 0.9666009]], dtype=float32)Tensor(shape=(2, 2), inds=[s0,1,2, s1,1,2], tags={}),
backend=cupy, dtype=float32, data=array([[0.9666009, 0.5304818], [0.5304818, 0.9666009]], dtype=float32)Tensor(shape=(2, 2), inds=[s0,1,2, s0,2,2], tags={}),
backend=cupy, dtype=float32, data=array([[0.9666009, 0.5304818], [0.5304818, 0.9666009]], dtype=float32)Tensor(shape=(2, 2), inds=[s0,1,2, s0,1,3], tags={}),
backend=cupy, dtype=float32, data=array([[0.9666009, 0.5304818], [0.5304818, 0.9666009]], dtype=float32)Tensor(shape=(2, 2), inds=[s0,1,3, s1,1,3], tags={}),
backend=cupy, dtype=float32, data=array([[0.9666009, 0.5304818], [0.5304818, 0.9666009]], dtype=float32)Tensor(shape=(2, 2), inds=[s0,1,3, s0,2,3], tags={}),
backend=cupy, dtype=float32, data=array([[0.9666009, 0.5304818], [0.5304818, 0.9666009]], dtype=float32)Tensor(shape=(2, 2), inds=[s0,1,3, s0,1,4], tags={}),
backend=cupy, dtype=float32, data=array([[0.9666009, 0.5304818], [0.5304818, 0.9666009]], dtype=float32)Tensor(shape=(2, 2), inds=[s0,1,4, s1,1,4], tags={}),
backend=cupy, dtype=float32, data=array([[0.9666009, 0.5304818], [0.5304818, 0.9666009]], dtype=float32)Tensor(shape=(2, 2), inds=[s0,1,4, s0,2,4], tags={}),
backend=cupy, dtype=float32, data=array([[0.9666009, 0.5304818], [0.5304818, 0.9666009]], dtype=float32)Tensor(shape=(2, 2), inds=[s0,1,4, s0,1,5], tags={}),
backend=cupy, dtype=float32, data=array([[0.9666009, 0.5304818], [0.5304818, 0.9666009]], dtype=float32)Tensor(shape=(2, 2), inds=[s0,1,5, s1,1,5], tags={}),
backend=cupy, dtype=float32, data=array([[0.9666009, 0.5304818], [0.5304818, 0.9666009]], dtype=float32)Tensor(shape=(2, 2), inds=[s0,1,5, s0,2,5], tags={}),
backend=cupy, dtype=float32, data=array([[0.9666009, 0.5304818], [0.5304818, 0.9666009]], dtype=float32)Tensor(shape=(2, 2), inds=[s0,2,0, s1,2,0], tags={}),
backend=cupy, dtype=float32, data=array([[0.9666009, 0.5304818], [0.5304818, 0.9666009]], dtype=float32)Tensor(shape=(2, 2), inds=[s0,2,0, s0,3,0], tags={}),
backend=cupy, dtype=float32, data=array([[0.9666009, 0.5304818], [0.5304818, 0.9666009]], dtype=float32)Tensor(shape=(2, 2), inds=[s0,2,0, s0,2,1], tags={}),
backend=cupy, dtype=float32, data=array([[0.9666009, 0.5304818], [0.5304818, 0.9666009]], dtype=float32)Tensor(shape=(2, 2), inds=[s0,2,1, s1,2,1], tags={}),
backend=cupy, dtype=float32, data=array([[0.9666009, 0.5304818], [0.5304818, 0.9666009]], dtype=float32)Tensor(shape=(2, 2), inds=[s0,2,1, s0,3,1], tags={}),
backend=cupy, dtype=float32, data=array([[0.9666009, 0.5304818], [0.5304818, 0.9666009]], dtype=float32)Tensor(shape=(2, 2), inds=[s0,2,1, s0,2,2], tags={}),
backend=cupy, dtype=float32, data=array([[0.9666009, 0.5304818], [0.5304818, 0.9666009]], dtype=float32)Tensor(shape=(2, 2), inds=[s0,2,2, s1,2,2], tags={}),
backend=cupy, dtype=float32, data=array([[0.9666009, 0.5304818], [0.5304818, 0.9666009]], dtype=float32)Tensor(shape=(2, 2), inds=[s0,2,2, s0,3,2], tags={}),
backend=cupy, dtype=float32, data=array([[0.9666009, 0.5304818], [0.5304818, 0.9666009]], dtype=float32)Tensor(shape=(2, 2), inds=[s0,2,2, s0,2,3], tags={}),
backend=cupy, dtype=float32, data=array([[0.9666009, 0.5304818], [0.5304818, 0.9666009]], dtype=float32)Tensor(shape=(2, 2), inds=[s0,2,3, s1,2,3], tags={}),
backend=cupy, dtype=float32, data=array([[0.9666009, 0.5304818], [0.5304818, 0.9666009]], dtype=float32)Tensor(shape=(2, 2), inds=[s0,2,3, s0,3,3], tags={}),
backend=cupy, dtype=float32, data=array([[0.9666009, 0.5304818], [0.5304818, 0.9666009]], dtype=float32)Tensor(shape=(2, 2), inds=[s0,2,3, s0,2,4], tags={}),
backend=cupy, dtype=float32, data=array([[0.9666009, 0.5304818], [0.5304818, 0.9666009]], dtype=float32)Tensor(shape=(2, 2), inds=[s0,2,4, s1,2,4], tags={}),
backend=cupy, dtype=float32, data=array([[0.9666009, 0.5304818], [0.5304818, 0.9666009]], dtype=float32)Tensor(shape=(2, 2), inds=[s0,2,4, s0,3,4], tags={}),
backend=cupy, dtype=float32, data=array([[0.9666009, 0.5304818], [0.5304818, 0.9666009]], dtype=float32)Tensor(shape=(2, 2), inds=[s0,2,4, s0,2,5], tags={}),
backend=cupy, dtype=float32, data=array([[0.9666009, 0.5304818], [0.5304818, 0.9666009]], dtype=float32)Tensor(shape=(2, 2), inds=[s0,2,5, s1,2,5], tags={}),
backend=cupy, dtype=float32, data=array([[0.9666009, 0.5304818], [0.5304818, 0.9666009]], dtype=float32)Tensor(shape=(2, 2), inds=[s0,2,5, s0,3,5], tags={}),
backend=cupy, dtype=float32, data=array([[0.9666009, 0.5304818], [0.5304818, 0.9666009]], dtype=float32)Tensor(shape=(2, 2), inds=[s0,3,0, s1,3,0], tags={}),
backend=cupy, dtype=float32, data=array([[0.9666009, 0.5304818], [0.5304818, 0.9666009]], dtype=float32)Tensor(shape=(2, 2), inds=[s0,3,0, s0,4,0], tags={}),
backend=cupy, dtype=float32, data=array([[0.9666009, 0.5304818], [0.5304818, 0.9666009]], dtype=float32)Tensor(shape=(2, 2), inds=[s0,3,0, s0,3,1], tags={}),
backend=cupy, dtype=float32, data=array([[0.9666009, 0.5304818], [0.5304818, 0.9666009]], dtype=float32)Tensor(shape=(2, 2), inds=[s0,3,1, s1,3,1], tags={}),
backend=cupy, dtype=float32, data=array([[0.9666009, 0.5304818], [0.5304818, 0.9666009]], dtype=float32)Tensor(shape=(2, 2), inds=[s0,3,1, s0,4,1], tags={}),
backend=cupy, dtype=float32, data=array([[0.9666009, 0.5304818], [0.5304818, 0.9666009]], dtype=float32)Tensor(shape=(2, 2), inds=[s0,3,1, s0,3,2], tags={}),
backend=cupy, dtype=float32, data=array([[0.9666009, 0.5304818], [0.5304818, 0.9666009]], dtype=float32)Tensor(shape=(2, 2), inds=[s0,3,2, s1,3,2], tags={}),
backend=cupy, dtype=float32, data=array([[0.9666009, 0.5304818], [0.5304818, 0.9666009]], dtype=float32)Tensor(shape=(2, 2), inds=[s0,3,2, s0,4,2], tags={}),
backend=cupy, dtype=float32, data=array([[0.9666009, 0.5304818], [0.5304818, 0.9666009]], dtype=float32)Tensor(shape=(2, 2), inds=[s0,3,2, s0,3,3], tags={}),
backend=cupy, dtype=float32, data=array([[0.9666009, 0.5304818], [0.5304818, 0.9666009]], dtype=float32)Tensor(shape=(2, 2), inds=[s0,3,3, s1,3,3], tags={}),
backend=cupy, dtype=float32, data=array([[0.9666009, 0.5304818], [0.5304818, 0.9666009]], dtype=float32)Tensor(shape=(2, 2), inds=[s0,3,3, s0,4,3], tags={}),
backend=cupy, dtype=float32, data=array([[0.9666009, 0.5304818], [0.5304818, 0.9666009]], dtype=float32)Tensor(shape=(2, 2), inds=[s0,3,3, s0,3,4], tags={}),
backend=cupy, dtype=float32, data=array([[0.9666009, 0.5304818], [0.5304818, 0.9666009]], dtype=float32)Tensor(shape=(2, 2), inds=[s0,3,4, s1,3,4], tags={}),
backend=cupy, dtype=float32, data=array([[0.9666009, 0.5304818], [0.5304818, 0.9666009]], dtype=float32)Tensor(shape=(2, 2), inds=[s0,3,4, s0,4,4], tags={}),
backend=cupy, dtype=float32, data=array([[0.9666009, 0.5304818], [0.5304818, 0.9666009]], dtype=float32)Tensor(shape=(2, 2), inds=[s0,3,4, s0,3,5], tags={}),
backend=cupy, dtype=float32, data=array([[0.9666009, 0.5304818], [0.5304818, 0.9666009]], dtype=float32)Tensor(shape=(2, 2), inds=[s0,3,5, s1,3,5], tags={}),
backend=cupy, dtype=float32, data=array([[0.9666009, 0.5304818], [0.5304818, 0.9666009]], dtype=float32)Tensor(shape=(2, 2), inds=[s0,3,5, s0,4,5], tags={}),
backend=cupy, dtype=float32, data=array([[0.9666009, 0.5304818], [0.5304818, 0.9666009]], dtype=float32)Tensor(shape=(2, 2), inds=[s0,4,0, s1,4,0], tags={}),
backend=cupy, dtype=float32, data=array([[0.9666009, 0.5304818], [0.5304818, 0.9666009]], dtype=float32)Tensor(shape=(2, 2), inds=[s0,4,0, s0,5,0], tags={}),
backend=cupy, dtype=float32, data=array([[0.9666009, 0.5304818], [0.5304818, 0.9666009]], dtype=float32)Tensor(shape=(2, 2), inds=[s0,4,0, s0,4,1], tags={}),
backend=cupy, dtype=float32, data=array([[0.9666009, 0.5304818], [0.5304818, 0.9666009]], dtype=float32)Tensor(shape=(2, 2), inds=[s0,4,1, s1,4,1], tags={}),
backend=cupy, dtype=float32, data=array([[0.9666009, 0.5304818], [0.5304818, 0.9666009]], dtype=float32)Tensor(shape=(2, 2), inds=[s0,4,1, s0,5,1], tags={}),
backend=cupy, dtype=float32, data=array([[0.9666009, 0.5304818], [0.5304818, 0.9666009]], dtype=float32)Tensor(shape=(2, 2), inds=[s0,4,1, s0,4,2], tags={}),
backend=cupy, dtype=float32, data=array([[0.9666009, 0.5304818], [0.5304818, 0.9666009]], dtype=float32)Tensor(shape=(2, 2), inds=[s0,4,2, s1,4,2], tags={}),
backend=cupy, dtype=float32, data=array([[0.9666009, 0.5304818], [0.5304818, 0.9666009]], dtype=float32)Tensor(shape=(2, 2), inds=[s0,4,2, s0,5,2], tags={}),
backend=cupy, dtype=float32, data=array([[0.9666009, 0.5304818], [0.5304818, 0.9666009]], dtype=float32)Tensor(shape=(2, 2), inds=[s0,4,2, s0,4,3], tags={}),
backend=cupy, dtype=float32, data=array([[0.9666009, 0.5304818], [0.5304818, 0.9666009]], dtype=float32)Tensor(shape=(2, 2), inds=[s0,4,3, s1,4,3], tags={}),
backend=cupy, dtype=float32, data=array([[0.9666009, 0.5304818], [0.5304818, 0.9666009]], dtype=float32)Tensor(shape=(2, 2), inds=[s0,4,3, s0,5,3], tags={}),
backend=cupy, dtype=float32, data=array([[0.9666009, 0.5304818], [0.5304818, 0.9666009]], dtype=float32)Tensor(shape=(2, 2), inds=[s0,4,3, s0,4,4], tags={}),
backend=cupy, dtype=float32, data=array([[0.9666009, 0.5304818], [0.5304818, 0.9666009]], dtype=float32)Tensor(shape=(2, 2), inds=[s0,4,4, s1,4,4], tags={}),
backend=cupy, dtype=float32, data=array([[0.9666009, 0.5304818], [0.5304818, 0.9666009]], dtype=float32)Tensor(shape=(2, 2), inds=[s0,4,4, s0,5,4], tags={}),
backend=cupy, dtype=float32, data=array([[0.9666009, 0.5304818], [0.5304818, 0.9666009]], dtype=float32)Tensor(shape=(2, 2), inds=[s0,4,4, s0,4,5], tags={}),
backend=cupy, dtype=float32, data=array([[0.9666009, 0.5304818], [0.5304818, 0.9666009]], dtype=float32)Tensor(shape=(2, 2), inds=[s0,4,5, s1,4,5], tags={}),
backend=cupy, dtype=float32, data=array([[0.9666009, 0.5304818], [0.5304818, 0.9666009]], dtype=float32)Tensor(shape=(2, 2), inds=[s0,4,5, s0,5,5], tags={}),
backend=cupy, dtype=float32, data=array([[0.9666009, 0.5304818], [0.5304818, 0.9666009]], dtype=float32)Tensor(shape=(2, 2), inds=[s0,5,0, s1,5,0], tags={}),
backend=cupy, dtype=float32, data=array([[0.9666009, 0.5304818], [0.5304818, 0.9666009]], dtype=float32)Tensor(shape=(2, 2), inds=[s0,5,0, s0,5,1], tags={}),
backend=cupy, dtype=float32, data=array([[0.9666009, 0.5304818], [0.5304818, 0.9666009]], dtype=float32)Tensor(shape=(2, 2), inds=[s0,5,1, s1,5,1], tags={}),
backend=cupy, dtype=float32, data=array([[0.9666009, 0.5304818], [0.5304818, 0.9666009]], dtype=float32)Tensor(shape=(2, 2), inds=[s0,5,1, s0,5,2], tags={}),
backend=cupy, dtype=float32, data=array([[0.9666009, 0.5304818], [0.5304818, 0.9666009]], dtype=float32)Tensor(shape=(2, 2), inds=[s0,5,2, s1,5,2], tags={}),
backend=cupy, dtype=float32, data=array([[0.9666009, 0.5304818], [0.5304818, 0.9666009]], dtype=float32)Tensor(shape=(2, 2), inds=[s0,5,2, s0,5,3], tags={}),
backend=cupy, dtype=float32, data=array([[0.9666009, 0.5304818], [0.5304818, 0.9666009]], dtype=float32)Tensor(shape=(2, 2), inds=[s0,5,3, s1,5,3], tags={}),
backend=cupy, dtype=float32, data=array([[0.9666009, 0.5304818], [0.5304818, 0.9666009]], dtype=float32)Tensor(shape=(2, 2), inds=[s0,5,3, s0,5,4], tags={}),
backend=cupy, dtype=float32, data=array([[0.9666009, 0.5304818], [0.5304818, 0.9666009]], dtype=float32)Tensor(shape=(2, 2), inds=[s0,5,4, s1,5,4], tags={}),
backend=cupy, dtype=float32, data=array([[0.9666009, 0.5304818], [0.5304818, 0.9666009]], dtype=float32)Tensor(shape=(2, 2), inds=[s0,5,4, s0,5,5], tags={}),
backend=cupy, dtype=float32, data=array([[0.9666009, 0.5304818], [0.5304818, 0.9666009]], dtype=float32)Tensor(shape=(2, 2), inds=[s0,5,5, s1,5,5], tags={}),
backend=cupy, dtype=float32, data=array([[0.9666009, 0.5304818], [0.5304818, 0.9666009]], dtype=float32)Tensor(shape=(2, 2), inds=[s1,0,0, s2,0,0], tags={}),
backend=cupy, dtype=float32, data=array([[0.9666009, 0.5304818], [0.5304818, 0.9666009]], dtype=float32)Tensor(shape=(2, 2), inds=[s1,0,0, s1,1,0], tags={}),
backend=cupy, dtype=float32, data=array([[0.9666009, 0.5304818], [0.5304818, 0.9666009]], dtype=float32)Tensor(shape=(2, 2), inds=[s1,0,0, s1,0,1], tags={}),
backend=cupy, dtype=float32, data=array([[0.9666009, 0.5304818], [0.5304818, 0.9666009]], dtype=float32)Tensor(shape=(2, 2), inds=[s1,0,1, s2,0,1], tags={}),
backend=cupy, dtype=float32, data=array([[0.9666009, 0.5304818], [0.5304818, 0.9666009]], dtype=float32)...
Then we can compute the output marginal contraction tree, which for factor graphs is just a matter of re-interpreting certain indices as ‘outputs’:
tree_sub = htn.contraction_tree(output_inds=output_inds, optimize=opt)
log2[SIZE]: 28.00 log10[FLOPs]: 13.46: 100%|████████████████████████████████████████| 128/128 [01:39<00:00, 1.29it/s]
The output tensor is larger than our sliced size, so we generate the output chunks lazily and process them one by one:
def _chunk_entropy(p_chunk):
    """Return -sum(p * log(p)) for a single chunk of the output tensor."""
    # using autoray handles numpy/cupy/torch/jax etc.
    return -do("sum", p_chunk * do("log", p_chunk))


# The chunks are produced lazily and reduced one at a time, so only the
# running total is kept between iterations.
S = sum(
    map(
        _chunk_entropy,
        tree_sub.gen_output_chunks(htn.arrays, progbar=True),
    )
)
100%|███████████████████████████████████████████████████████████████████████████████| 256/256 [02:15<00:00, 1.89it/s]
S
array(56.018044, dtype=float32)