{ "cells": [ { "cell_type": "markdown", "id": "48db7bce", "metadata": {}, "source": [ "# Contracting a large output lazily\n", "\n", "In this example we perform a contraction with an \n", "output that would be larger than can fit in memory. \n", "However, we can still generate it in chunks, which is \n", "sufficient to compute, for example:\n", "\n", "$$\n", "S = - \\sum_{\\{a, b, c, \\ldots\\}} p_{a, b, c, \\ldots} \\log p_{a, b, c, \\ldots}\n", "$$" ] }, { "cell_type": "code", "execution_count": 1, "id": "9724fa7d-a982-477d-b96b-0170d64fa6a4", "metadata": {}, "outputs": [], "source": [ "%config InlineBackend.figure_formats = ['svg']\n", "import cotengra as ctg\n", "import quimb.tensor as qtn\n", "from autoray import do" ] }, { "cell_type": "markdown", "id": "8a706dc0-3bff-4585-a6c1-b4efb889d96f", "metadata": {}, "source": [ "Use quimb to make an example factor graph / probability distribution:" ] }, { "cell_type": "code", "execution_count": 2, "id": "7a84d1c9-a922-42ea-b507-a3f46b88ee2b", "metadata": {}, "outputs": [], "source": [ "htn = qtn.HTN3D_classical_ising_partition_function(\n", "    6, 6, 6, beta=0.3,\n", ")" ] }, { "cell_type": "markdown", "id": "8e3c972a", "metadata": {}, "source": [ "Here we optionally first convert the tensor network's data to cupy GPU arrays:" ] }, { "cell_type": "code", "execution_count": 3, "id": "7ac01629", "metadata": {}, "outputs": [], "source": [ "def to_backend(x):\n", "    import cupy\n", "\n", "    return cupy.asarray(x, dtype=\"float32\")\n", "\n", "htn.apply_to_arrays(to_backend)" ] }, { "cell_type": "markdown", "id": "110d1659-4e86-408c-98be-a32c4bd7e5a6", "metadata": {}, "source": [ "Select a subset of output variables (more than we can store the full tensor for!):" ] }, { "cell_type": "code", "execution_count": 4, "id": "80b1d3f3-2d21-4594-885c-b08088d26e03", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "36" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ 
"output_inds = tuple(\n", " f\"s{i},{j},{k}\"\n", " for i in range(4)\n", " for j in range(3)\n", " for k in range(3)\n", ")\n", "len(output_inds)" ] }, { "cell_type": "code", "execution_count": 5, "id": "0de83219-7eea-4366-a2e0-ab1367970380", "metadata": {}, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2023-07-29T15:45:42.724610\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.7.1, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", 
" \n" ], "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "htn.draw(highlight_inds=output_inds)" ] }, { "cell_type": "code", "execution_count": 6, "id": "4400dfe4-115f-4d58-9b1d-c12213396f91", "metadata": {}, "outputs": [], "source": [ "opt = ctg.ReusableHyperOptimizer(\n", "    minimize='combo',\n", "    # here we put the actual amount of storage we are limited to\n", "    slicing_reconf_opts={'target_size': 2**28},\n", "    # the amount of time we want to spend searching,\n", "    # given we can compute at approximately 1e11 ops / sec\n", "    max_time=\"rate:1e11\",\n", "    progbar=True,\n", ")" ] }, { "cell_type": "markdown", "id": "dd3daa4d-de8a-4963-bde8-a4760279199c", "metadata": {}, "source": [ "First, if we need to normalize, we compute the full partition function:" ] }, { "cell_type": "code", "execution_count": 7, "id": "9463072f-4d1e-404f-9be5-334f717bf86a", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "log2[SIZE]: 28.00 log10[FLOPs]: 12.80:  24%|█████████▉ | 31/128 [01:04<03:22, 2.09s/it]\n" ] } ], "source": [ "tree_Z = htn.contraction_tree(output_inds=(), optimize=opt)" ] }, { "cell_type": "markdown", "id": "6c0bd2b3", "metadata": {}, "source": [ "Since it could be a very large or very small number, we actively renormalize the\n", "tensors while contracting, into a separate mantissa and exponent:" ] }, { "cell_type": "code", "execution_count": 8, "id": "953e8dd4-25e0-4f36-a39c-cbf87f0e3236", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "100%|█████████████████████████████████████████████████████████████████████████████████| 32/32 [00:07<00:00, 4.56it/s]\n" ] }, { "data": { "text/plain": [ "(array(4.2891397, dtype=float32), array(78.32228, dtype=float32))" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "Z_mantissa, Z_exponent = tree_Z.contract(\n", "    htn.arrays, strip_exponent=True, progbar=True\n", ")\n", "Z_mantissa, Z_exponent" ] }, { "cell_type": 
"code", "execution_count": 9, "id": "896187e0-98de-4479-bd1f-77bb5c2d1d03", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
TensorNetwork(tensors=540, indices=216) (HTML repr truncated)
" ], "text/plain": [ "TensorNetwork(tensors=540, indices=216)" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# we can normalize by setting the negative exponent\n", "htn.exponent = -Z_exponent\n", "# this then spreads the exponent among all the actual tensors\n", "htn.equalize_norms_()" ] }, { "cell_type": "markdown", "id": "d12c5834-7010-4799-b839-803913618c3f", "metadata": {}, "source": [ "Then we can compute the output marginal contraction tree, which for factor\n", "graphs is just a matter of re-interpreting certain indices as 'outputs':" ] }, { "cell_type": "code", "execution_count": 10, "id": "16db36bb-189b-4088-acf6-e70d85bff8cf", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "log2[SIZE]: 28.00 log10[FLOPs]: 13.46: 100%|████████████████████████████████████████| 128/128 [01:39<00:00, 1.29it/s]\n" ] } ], "source": [ "tree_sub = htn.contraction_tree(output_inds=output_inds, optimize=opt)" ] }, { "cell_type": "markdown", "id": "122e1cf7-9977-460a-96ca-df0c2ce6b489", "metadata": {}, "source": [ "The output tensor is larger than our sliced size, so we generate the output\n", "chunks lazily and process them one by one:" ] }, { "cell_type": "code", "execution_count": 11, "id": "d899a845", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "100%|███████████████████████████████████████████████████████████████████████████████| 256/256 [02:15<00:00, 1.89it/s]\n" ] } ], "source": [ "S = sum(\n", "    # using autoray handles numpy/cupy/torch/jax etc.\n", "    -do(\"sum\", p_chunk * do(\"log\", p_chunk))\n", "    for p_chunk in tree_sub.gen_output_chunks(\n", "        htn.arrays, progbar=True,\n", "    )\n", ")" ] }, { "cell_type": "code", "execution_count": 12, "id": "cfc90d5a-93ed-4816-8400-cb92190e29b6", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array(56.018044, dtype=float32)" ] }, "execution_count": 12, "metadata": {}, "output_type": 
"execute_result" } ], "source": [ "S" ] } ], "metadata": { "kernelspec": { "display_name": "torch", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.12" } }, "nbformat": 4, "nbformat_minor": 5 }