{ "cells": [ { "cell_type": "markdown", "id": "43a0bfd5-762c-42c4-bf40-7c7cba1c33f4", "metadata": {}, "source": [ "# Tree Surgery\n", "\n", "The core data structure in [`cotengra`](cotengra) is the \n", "[`ContractionTree`](cotengra.core.ContractionTree), which, as well as describing \n", "the tree, generates all the required intermediate indices, equations etc. \n", "This page describes some aspects of the design and ways you can modify or \n", "construct trees yourself.\n", "\n", "First we just generate a small random contraction:" ] }, { "cell_type": "code", "execution_count": 1, "id": "a4b0dd86-073a-499e-8a20-2c2eb9833b71", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "dn,bhl,afj,cejk,cdefglmno,gh,i,k,im,o->ba\n" ] }, { "data": { "image/svg+xml": [ "\n", " 2024-02-26T16:10:39.573258\n", " image/svg+xml\n", " Matplotlib v3.8.3, https://matplotlib.org/\n", "\n" ], "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" }, { "data": { "text/plain": [ "(
, )" ] }, "execution_count": 1, "metadata": {}, "output_type": "execute_result" } ], "source": [ "%config InlineBackend.figure_formats = ['svg']\n", "import cotengra as ctg # noqa\n", "\n", "# generate a random contraction\n", "inputs, output, shapes, size_dict = ctg.utils.rand_equation(10, 3, n_out=2, seed=4)\n", "print(ctg.utils.inputs_output_to_eq(inputs, output))\n", "\n", "# turn into a hypergraph in order to visualize\n", "hg = ctg.get_hypergraph(inputs, output, size_dict)\n", "hg.plot()" ] }, { "cell_type": "markdown", "id": "cffdc1f6-cbbb-4c81-b04f-8bc14698a2b5", "metadata": {}, "source": [ "## Design" ] }, { "cell_type": "markdown", "id": "e96566e9-3cd7-4a5b-81ab-df876ff62e3a", "metadata": {}, "source": [ "We can create an empty tree with the following:" ] }, { "cell_type": "code", "execution_count": 2, "id": "b7680a24-d233-4f53-a404-f9627fd99ba3", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "tree = ctg.ContractionTree(inputs, output, size_dict)\n", "tree" ] }, { "cell_type": "markdown", "id": "3efc6c42-057f-4220-bbb1-8496b395438d", "metadata": {}, "source": [ "```{note}\n", "You can also initialize an empty tree with [`ContractionTree.from_eq`](cotengra.core.ContractionTree.from_eq) or construct a completed tree with [`ContractionTree.from_path`](cotengra.core.ContractionTree.from_path) and [`ContractionTree.from_info`](cotengra.core.ContractionTree.from_info). The [`search`](cotengra.HyperOptimizer.search) method of optimizers also directly returns a tree.\n", "```" ] }, { "cell_type": "markdown", "id": "f9986d72-400b-455c-b922-ef0bba49f1b5", "metadata": {}, "source": [ "The nodes of the tree are frozen sets of integers, describing groups of inputs which form intermediate tensors. 
Initially only the final output tensor (tree root node), and input tensors (tree leaf nodes) are known:" ] }, { "cell_type": "code", "execution_count": 3, "id": "33072c14-7eba-46d7-a95f-537b05b78270", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "frozenset({0, 1, 2, 3, 4, 5, 6, 7, 8, 9})\n" ] } ], "source": [ "print(tree.root)" ] }, { "cell_type": "code", "execution_count": 4, "id": "0da716e0-a7b9-4a14-a3ad-dc012c5e5de7", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "frozenset({0})\n", "frozenset({1})\n", "frozenset({2})\n", "frozenset({3})\n", "frozenset({4})\n", "frozenset({5})\n", "frozenset({6})\n", "frozenset({7})\n", "frozenset({8})\n", "frozenset({9})\n" ] } ], "source": [ "for node in tree.gen_leaves():\n", " print(node)" ] }, { "cell_type": "markdown", "id": "9ccb6fe0-76bd-4980-b7df-56f9b4ae2088", "metadata": {}, "source": [ "In order to complete the tree we need to find $N - 1$ contractions (branches) via either merges of parentless nodes (like the leaves above) or partitions of childless nodes (like the root above). [`contract_nodes`](cotengra.core.ContractionTree.contract_nodes) is the method used to form these merges for an arbitrary set of nodes. 
" ] }, { "cell_type": "markdown", "id": "6b4cbf1a-7f5e-479a-baef-c3b22be1281d", "metadata": {}, "source": [ "Here we merge two leaves to form a new intermediate (agglomeratively building the tree from the bottom):" ] }, { "cell_type": "code", "execution_count": 5, "id": "9323f9d6-efd9-46f3-98a9-125a4ef74a12", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "frozenset({0, 1})" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "parent = tree.contract_nodes([frozenset([0]), frozenset([1])])\n", "parent" ] }, { "cell_type": "markdown", "id": "ad34b81a-5de3-4ae4-9a7c-ff7620fd9d00", "metadata": {}, "source": [ "Here we split the root (divisively building the tree from the top):" ] }, { "cell_type": "code", "execution_count": 6, "id": "4c8cb1b5-228a-4d53-aee4-7a2cd7eee96b", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "frozenset({0, 1, 2, 3, 4, 5, 6, 7, 8, 9})" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "left, right = frozenset(range(0, 5)), frozenset(range(5, 10))\n", "\n", "# this should output the root\n", "tree.contract_nodes([left, right])" ] }, { "cell_type": "markdown", "id": "1eac789d-c003-4490-a485-013374670410", "metadata": {}, "source": [ "Note how we can declare `left` and `right` here as intermediates even though we don't know exactly how they will be formed yet. 
We can already work out certain properties of such nodes, for example their indices ('legs') and their size:" ] }, { "cell_type": "code", "execution_count": 9, "id": "96586214-c7c8-4fc2-bdb7-f8f011a9f6b3", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "({'b': 1, 'h': 1, 'a': 1, 'k': 1, 'g': 1, 'm': 1, 'o': 1}, 432)" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "tree.get_legs(left), tree.get_size(left)" ] }, { "cell_type": "markdown", "id": "b1e19a0f-f762-48c5-b4ed-83798f480b7f", "metadata": {}, "source": [ "But other information, such as the flops required to form a node, needs its children to be specified first. Both the root and `parent` already have children, so we can query them:" ] }, { "cell_type": "code", "execution_count": 10, "id": "34474da1-e7c1-4eb0-bd6c-c4d98e068084", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "432" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "tree.get_flops(tree.root)" ] }, { "cell_type": "code", "execution_count": 11, "id": "43382b1f-c428-40df-acb5-5c2bd6e9e69d", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'abc,de->abcde'" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "tree.get_einsum_eq(parent)" ] }, { "cell_type": "markdown", "id": "9a77d4b9-b650-45eb-93d7-4204a091dbc3", "metadata": {}, "source": [ "The core tree information is stored as a mapping of children like so:" ] }, { "cell_type": "code", "execution_count": 12, "id": "fbb6a4a0-01eb-4d6b-a3de-b0ff434a12c0", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{frozenset({0, 1}): (frozenset({1}), frozenset({0})),\n", " frozenset({0, 1, 2, 3, 4, 5, 6, 7, 8, 9}): (frozenset({5, 6, 7, 8, 9}),\n", " frozenset({0, 1, 2, 3, 4}))}" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "tree.children" ] }, { "cell_type": "markdown", "id": "a950263b-d13e-430e-91d1-3a40767be94b", "metadata": {}, "source": [ "We can compute the flops and any other local information for any node that is a key of this mapping. This information is cached in `tree.info`.\n", "\n", "You can get all remaining 'incomplete' nodes by calling [`tree.get_incomplete_nodes`](cotengra.core.ContractionTree.get_incomplete_nodes). This returns a dictionary, where each key is a 'childless' node, and each value is the list of 'parentless' nodes that are within its subgraph:" ] }, { "cell_type": "code", "execution_count": 14, "id": "e804ed26", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{frozenset({0, 1, 2, 3, 4}): [frozenset({2}),\n", " frozenset({3}),\n", " frozenset({4}),\n", " frozenset({0, 1})],\n", " frozenset({5, 6, 7, 8, 9}): [frozenset({5}),\n", " frozenset({6}),\n", " frozenset({7}),\n", " frozenset({8}),\n", " frozenset({9})]}" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "tree.get_incomplete_nodes()" ] }, { "cell_type": "markdown", "id": "93bcc9e2-44bb-4917-b6bf-fcebab029f4a", "metadata": {}, "source": [ "Each of these 'empty subtrees' needs to be filled in in order to complete the tree.\n", "\n", "When calling `contract_nodes` we can actually specify an arbitrary collection of (uncontracted) nodes; these can be seen as the leaves of a subtree, which is a mini contraction problem in its own right. Hence the function itself takes the `optimize` kwarg used to specify how to optimize this sub contraction. Here we'll make use of this to complete our tree." 
] }, { "cell_type": "code", "execution_count": 11, "id": "d908d0ca-5c76-47fa-9a56-f8592e298add", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "True" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# fill in the subtree with `left` as root\n", "tree.contract_nodes([parent] + [frozenset([i]) for i in range(2, 5)])\n", "\n", "# fill in the subtree with `right` as root\n", "tree.contract_nodes([frozenset([i]) for i in range(5, 10)])\n", "\n", "# our tree should now be complete\n", "tree.is_complete()" ] }, { "cell_type": "markdown", "id": "c1e286bd-a5c6-4806-9c9d-870e7434b255", "metadata": {}, "source": [ "We can now plot our tree, check full costs, and use it to perform contractions etc:" ] }, { "cell_type": "code", "execution_count": 12, "id": "401afda9-0f61-44d5-a260-d5a15fbdb801", "metadata": {}, "outputs": [ { "data": { "image/svg+xml": [ "\n", " 2024-02-15T17:56:11.709441\n", " image/svg+xml\n", " Matplotlib v3.8.2, https://matplotlib.org/\n", "\n" ], "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" }, { "data": { "text/plain": [ "(
, )" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "tree.plot_flat()" ] }, { "cell_type": "code", "execution_count": 17, "id": "a5b9546b-f73b-4e4d-aac4-9e8fac5ea803", "metadata": {}, "outputs": [ { "data": { "image/svg+xml": [ "\n", " 2024-02-15T17:57:20.172825\n", " image/svg+xml\n", " Matplotlib v3.8.2, https://matplotlib.org/\n", "\n" ], "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" }, { "data": { "text/plain": [ "(
, )" ] }, "execution_count": 17, "metadata": {}, "output_type": "execute_result" } ], "source": [ "tree.plot_tent(order=True)" ] }, { "cell_type": "code", "execution_count": 18, "id": "992d941d-31f0-4e3d-afd6-6cc2d7a3d50d", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(9.754887502163468, 21586.0)" ] }, "execution_count": 18, "metadata": {}, "output_type": "execute_result" } ], "source": [ "tree.contraction_width(), tree.contraction_cost()" ] }, { "cell_type": "code", "execution_count": 19, "id": "15c000dd-b2a9-4159-b76c-8633f4130bae", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([[ 104.78668702, 164.19133867],\n", " [ 161.34287932, 69.65644265],\n", " [-116.42788952, -46.96147145]])" ] }, "execution_count": 19, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import numpy as np\n", "arrays = [np.random.randn(*s) for s in shapes]\n", "tree.contract(arrays)" ] }, { "cell_type": "markdown", "id": "a4c0e8b7-f6a3-474b-8dfb-9f2d9c6bd792", "metadata": {}, "source": [ "We can also generate an explicit **contraction path**, which is a specific ordering of the tree (the default for exact contractions being depth first) and the format used by `opt_einsum` and accepted by `quimb` too:" ] }, { "cell_type": "code", "execution_count": 20, "id": "2ca5c822-0987-4b44-bdd2-42d11c5c48cf", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "((6, 8), (7, 8), (5, 6), (5, 6), (0, 1), (2, 4), (0, 1), (1, 2), (0, 1))" ] }, "execution_count": 20, "metadata": {}, "output_type": "execute_result" } ], "source": [ "path = tree.get_path()\n", "path" ] }, { "cell_type": "markdown", "id": "40de3597-6770-4187-8c57-b5757986e805", "metadata": {}, "source": [ "Once we have a complete tree, either by explicitly building it or more likely as returned by an optimizer, we can also *modify* it in various ways as detailed below. 
" ] }, { "cell_type": "markdown", "id": "b0d63433-101d-474d-a13b-e29f77ffe163", "metadata": {}, "source": [ "## Index Slicing\n", "\n", "Slicing is the technique of choosing some indices to explicitly sum over rather than include as tensor dimensions - thereby taking indexed slices of those tensors. It is also known as **variable projection** and **bond cutting**. The overall effect is to turn a single contraction into many independent, easier contractions as depicted below:\n", "\n", "

\"cotengra\"

\n", "\n", "While slicing *every* index would be equivalent to performing the naive einsum and thus exponentially slower, carefully choosing just a few *can* result in little to no overhead. What you gain is:\n", "\n", "1. **Each contraction can require drastically less memory**\n", "2. **Each contraction can be performed in parallel**\n", "\n", "In general, one specifies slicing in the `HyperOptimizer` loop (see [Slicing and Subtree Reconfiguration](advanced.ipynb#slicing-and-subtree-reconfiguration)); this section covers manually slicing a tree and the details." ] }, { "cell_type": "code", "execution_count": 21, "id": "c50f5fff", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "sAOÆÍÐÑÒð,Üâ,eDWÀÇãçõöĈ,dmÉý,aÁĀ,BOÌÛó,fÔÙă,qÕÝáûąĆ,EKÍÞðú,yáñ÷,knoRÀÃ÷,csÙíþ,hPÊÕ×,lÐãî,S,FRý,hpyÂÞä,EØÜâñö,Hïøû,bvLÈ×þĀĂ,ÈÏå,oCQòôāĄą,biwGIà,æìòù,xMTî,gWÎÚùā,uNVYÊÏĂ,cìĈ,rTÁÃÝ,eAÉë,jtGÚüĄĆ,ÄÇéć,dv,pÄÓßÿ,Dí,zÅËåúă,juVXæ,Jïø,fqPÅÖè,HKÔõ,JSËÛë,awNÖßàéê,rxZ,lntIUÎÓä,gkzQXÂØèć,LUZÑêó,iÌüÿ,BCMç,FÆÒ,mYô->\n" ] }, { "data": { "image/svg+xml": [ "\n", " 2024-02-15T17:57:31.897726\n", " image/svg+xml\n", " Matplotlib v3.8.2, https://matplotlib.org/\n", "\n" ], "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" }, { "data": { "text/plain": [ "(
, )" ] }, "execution_count": 21, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# create a reasonably large random contraction:\n", "inputs, output, shapes, size_dict = ctg.utils.rand_equation(n=50, reg=5, seed=0, d_max=2)\n", "arrays = [np.random.uniform(size=s) for s in shapes]\n", "\n", "# visualize it:\n", "print(ctg.utils.inputs_output_to_eq(inputs, output))\n", "ctg.HyperGraph(inputs, output, size_dict).plot()" ] }, { "cell_type": "code", "execution_count": 22, "id": "afbd6414", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(572054568704.0, 27.0)" ] }, "execution_count": 22, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# start with a simple greedy contraction path / tree\n", "tree = ctg.array_contract_tree(inputs, output, size_dict, optimize='greedy')\n", "\n", "# check the time and space requirements\n", "tree.contraction_cost(), tree.contraction_width()" ] }, { "cell_type": "markdown", "id": "e1d007b4", "metadata": {}, "source": [ "We can call the [`slice`](cotengra.core.ContractionTree.slice) method of a tree \n", "to slice it." 
] }, { "cell_type": "code", "execution_count": 23, "id": "32431e39", "metadata": {}, "outputs": [], "source": [ "tree_s = tree.slice(target_size=2**20)" ] }, { "cell_type": "markdown", "id": "41b915d8", "metadata": {}, "source": [ "```{hint}\n", "This method and various others can also be called in place by using the variant with a trailing underscore (e.g. `slice_`).\n", "```" ] }, { "cell_type": "code", "execution_count": 24, "id": "b183dbd6", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "({'M': SliceInfo(inner=True, ind='M', size=2, project=None),\n", " 'Q': SliceInfo(inner=True, ind='Q', size=2, project=None),\n", " 'Á': SliceInfo(inner=True, ind='Á', size=2, project=None),\n", " 'Ã': SliceInfo(inner=True, ind='Ã', size=2, project=None),\n", " 'Æ': SliceInfo(inner=True, ind='Æ', size=2, project=None),\n", " 'á': SliceInfo(inner=True, ind='á', size=2, project=None),\n", " 'ç': SliceInfo(inner=True, ind='ç', size=2, project=None),\n", " 'ñ': SliceInfo(inner=True, ind='ñ', size=2, project=None),\n", " 'ÿ': SliceInfo(inner=True, ind='ÿ', size=2, project=None),\n", " 'ą': SliceInfo(inner=True, ind='ą', size=2, project=None)},\n", " 1024)" ] }, "execution_count": 24, "metadata": {}, "output_type": "execute_result" } ], "source": [ "tree_s.sliced_inds, tree_s.nslices" ] }, { "cell_type": "code", "execution_count": 25, "id": "3d1ada9d", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(656181444608.0, 20.0)" ] }, "execution_count": 25, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# check the new time and space requirements\n", "tree_s.contraction_cost(), tree_s.contraction_width()" ] }, { "cell_type": "markdown", "id": "94d68c60", "metadata": {}, "source": [ "The tree now has 10 indices sliced out, and represents 1024 individual \n", "contractions each requiring 128x less memory. 
The total cost has increased \n", "slightly; we can check this ***'slicing overhead'*** directly:" ] }, { "cell_type": "code", "execution_count": 26, "id": "fc709fd0", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "1.147060928286249" ] }, "execution_count": 26, "metadata": {}, "output_type": "execute_result" } ], "source": [ "tree_s.contraction_cost() / tree.contraction_cost()" ] }, { "cell_type": "markdown", "id": "43b95971-2d94-438f-93ac-00e3d513441a", "metadata": {}, "source": [ "So less than 15% more floating point operations overall (arising from \n", "redundantly repeated contractions).\n", "\n", "```{note}\n", "Internally, the tree constructs a [`SliceFinder`](cotengra.slicer.SliceFinder)\n", "object and uses it to search for good indices to slice. For full control, you \n", "could do this yourself and then manually call \n", "[`remove_ind`](cotengra.core.ContractionTree.remove_ind) on the tree.\n", "```\n", "\n", "Once a [`ContractionTree`](cotengra.core.ContractionTree) has been sliced, the\n", "[`contract`](cotengra.core.ContractionTree.contract) method can be used to \n", "automatically perform the sliced contraction:" ] }, { "cell_type": "code", "execution_count": 27, "id": "3e874658-6956-4fe2-8e3c-1dee6d02a02f", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 1024/1024 [00:25<00:00, 40.04it/s]\n" ] }, { "data": { "text/plain": [ "5.725156944176958e+22" ] }, "execution_count": 27, "metadata": {}, "output_type": "execute_result" } ], "source": [ "tree_s.contract(arrays, progbar=True)" ] }, { "cell_type": "markdown", "id": "b542a047-96e7-4e2a-873b-e69f1fc14cad", "metadata": {}, "source": [ "See the [main contraction page](contraction.ipynb) for more details on\n", "performing the actual sliced contraction once the indices have been found." 
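, "\n",
"\n",
"To make concrete what each slice computes, here is a minimal sketch in plain `numpy` (not cotengra's implementation; the arrays `x`, `y`, `z` and the equation are made up for illustration). Fixing the value of a summed index gives an independent, smaller contraction, and adding up the results over all values recovers the full answer:\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"rng = np.random.default_rng(0)\n",
"x = rng.uniform(size=(2, 3))  # indices a, b\n",
"y = rng.uniform(size=(3, 4))  # indices b, c\n",
"z = rng.uniform(size=(4, 2))  # indices c, a\n",
"\n",
"# full contraction, summing over all of a, b and c at once\n",
"full = np.einsum('ab,bc,ca->', x, y, z)\n",
"\n",
"# 'slice' index b: one smaller contraction per value of b, then sum\n",
"sliced = sum(\n",
"    np.einsum('a,c,ca->', x[:, b], y[b, :], z)\n",
"    for b in range(3)\n",
")\n",
"\n",
"assert np.allclose(full, sliced)\n",
"```\n",
"\n",
"In a larger network, intermediate contractions that do not involve the sliced index are recomputed for every slice, which is where the slicing overhead comes from."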
] }, { "cell_type": "markdown", "id": "3eff51f2-dbde-4341-866b-a224e9d03f08", "metadata": {}, "source": [ "## Subtree Reconfiguration\n", "\n", "Any subtree of a contraction tree itself describes a smaller contraction, with \n", "the subtree leaves being the effective inputs (generally intermediate tensors) \n", "and the subtree root being the effective output (also generally an \n", "intermediate). One advantage of cotengra keeping an explicit representation of \n", "the contraction tree is that such subtrees can be easily selected and \n", "re-optimized as illustrated in the following schematic:\n", "\n", "
*(Schematic: a subtree of the contraction tree is selected and re-optimized.)*
\n", "\n", "(Note in general the subtree being optimized will be closer in size to 10 \n", "intermediates or so.)\n", "\n", "If we do this and improve the contraction cost of a **subtree** (e.g. by using an optimal contraction path), then the contraction cost of the **whole tree** is improved. Moreover we can iterate across many or all subtrees in a `ContractionTree`, reconfiguring them and thus potentially updating the entire tree in incremental 'local' steps.\n", "\n", "The method to call for this is \n", "[`subtree_reconfigure`](cotengra.core.ContractionTree.subtree_reconfigure)." ] }, { "cell_type": "code", "execution_count": 25, "id": "89b9db97", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "0it [00:00, ?it/s]" ] }, { "name": "stderr", "output_type": "stream", "text": [ "log2[SIZE]: 32.00 log10[FLOPs]: 12.98: : 500it [00:00, 660.40it/s]\n" ] } ], "source": [ "# generate a tree\n", "inputs, output, shapes, size_dict = ctg.utils.lattice_equation([24, 30], d_max=2)\n", "tree = ctg.array_contract_tree(inputs, output, size_dict, optimize='greedy')\n", "\n", "# reconfigure it (call tree.subtree_reconfigure? to see many options)\n", "tree_r = tree.subtree_reconfigure(progbar=True)" ] }, { "cell_type": "code", "execution_count": 26, "id": "86400cbe", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "1253.3836847767188" ] }, "execution_count": 26, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# check the speedup\n", "tree.total_flops() / tree_r.total_flops()" ] }, { "cell_type": "markdown", "id": "abb701ea", "metadata": {}, "source": [ "Since it is a local optimization it is possible to get stuck. 
\n", "[`subtree_reconfigure_forest`](cotengra.core.ContractionTree.subtree_reconfigure_forest)\n", "offers a basic stochastic search of multiple reconfigurations that can avoid this and also be easily parallelized:" ] }, { "cell_type": "code", "execution_count": 27, "id": "1decea62", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "log2[SIZE]: 32.00 log10[FLOPs]: 12.44: 100%|██████████| 10/10 [00:05<00:00, 1.80it/s]\n" ] } ], "source": [ "tree_f = tree.subtree_reconfigure_forest(progbar=True, num_trees=4)" ] }, { "cell_type": "code", "execution_count": 28, "id": "96c20678", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "4312.69228449926" ] }, "execution_count": 28, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# check the speedup\n", "tree.total_flops() / tree_f.total_flops()" ] }, { "cell_type": "markdown", "id": "238ef29f", "metadata": {}, "source": [ "So indeed better still, here by a further factor of about 3.4.\n", "\n", "Subtree reconfiguration is often powerful enough to allow even 'bad' initial paths (like those generated by `'greedy'`) to become very high quality." ] }, { "cell_type": "markdown", "id": "4385f8e0-b233-43cb-a2c1-09316889d5c7", "metadata": {}, "source": [ "## Dynamic Slicing\n", "\n", "A powerful application for reconfiguration (first implemented in ['Classical Simulation of Quantum Supremacy Circuits'](https://arxiv.org/abs/2005.06787)) is to interleave it with *slicing*. Namely:\n", "\n", "1. Choose an index to slice\n", "2. Reconfigure subtrees to account for the slightly different TN structure without this index\n", "3. Check whether the tree has reached the target size; if not, return to 1.\n", "\n", "In this way, the contraction tree is slowly updated to account for potentially many indices being sliced.\n", "\n", "For example, imagine we wanted to slice the `tree_f` from above to achieve a maximum size of `2**28` (approx suitable for 8GB of memory). 
We could directly slice it without changing the tree structure at all:" ] }, { "cell_type": "code", "execution_count": 29, "id": "a02e567e", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'Τ': SliceInfo(inner=True, ind='Τ', size=2, project=None),\n", " 'ϟ': SliceInfo(inner=True, ind='ϟ', size=2, project=None),\n", " 'К': SliceInfo(inner=True, ind='К', size=2, project=None),\n", " 'ѕ': SliceInfo(inner=True, ind='ѕ', size=2, project=None)}" ] }, "execution_count": 29, "metadata": {}, "output_type": "execute_result" } ], "source": [ "tree_s = tree_f.slice(target_size=2**28)\n", "tree_s.sliced_inds" ] }, { "cell_type": "markdown", "id": "51e9b3f6", "metadata": {}, "source": [ "Or we could interleave subtree reconfiguration with the slicing:" ] }, { "cell_type": "code", "execution_count": 30, "id": "0b70016e", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "log2[SIZE]: 28.00 log10[FLOPs]: 12.36: : 4it [00:04, 1.21s/it]\n" ] }, { "data": { "text/plain": [ "{'ì': SliceInfo(inner=True, ind='ì', size=2, project=None),\n", " 'ϟ': SliceInfo(inner=True, ind='ϟ', size=2, project=None),\n", " 'К': SliceInfo(inner=True, ind='К', size=2, project=None),\n", " 'ѕ': SliceInfo(inner=True, ind='ѕ', size=2, project=None)}" ] }, "execution_count": 30, "metadata": {}, "output_type": "execute_result" } ], "source": [ "tree_sr = tree_f.slice_and_reconfigure(target_size=2**28, progbar=True)\n", "tree_sr.sliced_inds" ] }, { "cell_type": "code", "execution_count": 31, "id": "f1580e24", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "1.6619162859973315" ] }, "execution_count": 31, "metadata": {}, "output_type": "execute_result" } ], "source": [ "tree_s.total_flops() / tree_sr.total_flops()" ] }, { "cell_type": "markdown", "id": "bef503d1", "metadata": {}, "source": [ "We can see it has achieved the target size with about 1.66x better cost. 
There is also a 'forested' version of this algorithm which again performs a stochastic search of multiple possible slicing+reconfiguring options:" ] }, { "cell_type": "code", "execution_count": 32, "id": "89ef4271", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "log2[SIZE]: 28.00 log10[FLOPs]: 12.36: : 4it [00:08, 2.25s/it]\n" ] }, { "data": { "text/plain": [ "1.6619162861844943" ] }, "execution_count": 32, "metadata": {}, "output_type": "execute_result" } ], "source": [ "tree_fsr = tree_f.slice_and_reconfigure_forest(target_size=2**28, progbar=True)\n", "tree_s.total_flops() / tree_fsr.total_flops()" ] }, { "cell_type": "markdown", "id": "898c19e6", "metadata": {}, "source": [ "Here the result is essentially the same as for the non-forested version. The foresting looks roughly like the following:\n", "\n", "
*(Schematic: forested slicing and reconfiguration.)*
\n", "\n", "The subtree reconfiguration within the slicing can *itself be forested* for a doubly forested algorithm. This will give the highest quality (but also slowest) search.\n" ] }, { "cell_type": "code", "execution_count": 33, "id": "f941d6df", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/media/johnnie/Storage2TB/Sync/dev/python/cotengra/cotengra/parallel.py:276: UserWarning: Parallel specified but no existing global dask client found... created one (with 8 workers).\n", " warnings.warn(\n", "log2[SIZE]: 27.00 log10[FLOPs]: 11.55: : 4it [01:39, 24.76s/it]\n" ] } ], "source": [ "tree_fsfr = tree_f.slice_and_reconfigure_forest(\n", "    target_size=2**28,\n", "    num_trees=4,\n", "    parallel='dask',\n", "    reconf_opts=dict(\n", "        subtree_size=12,\n", "        forested=True,\n", "        parallel='dask',\n", "        num_trees=4,\n", "    ),\n", "    progbar=True,\n", ")" ] }, { "cell_type": "code", "execution_count": 34, "id": "aa9d385f", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "10.674935697795002" ] }, "execution_count": 34, "metadata": {}, "output_type": "execute_result" } ], "source": [ "tree_s.total_flops() / tree_fsfr.total_flops()" ] }, { "cell_type": "markdown", "id": "f0aea9b4", "metadata": {}, "source": [ "We've set the `subtree_size` here to `12` for higher quality reconfiguration, but reduced the `num_trees` in the forests (from default `8`) to `4`, which will still lead to 4 x 4 = 16 trees being generated at each step. This time we see a substantial improvement: about 10.7x fewer FLOPs than the directly sliced tree, compared with about 1.66x above. This level of effort might only be required for very heavily sliced contraction trees, and in this case it might be best simply to trial many initial paths with a basic `slice_and_reconfigure`." 
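, "\n",
"\n",
"A minimal sketch of that last suggestion, assuming a much smaller lattice than above so that it runs quickly, and using only two initial optimizers (`'greedy'` and `'auto'`) as stand-ins for a larger set of trial paths; the lattice size and `target_size` are arbitrary choices for illustration:\n",
"\n",
"```python\n",
"import cotengra as ctg\n",
"\n",
"inputs, output, shapes, size_dict = ctg.utils.lattice_equation([8, 8], d_max=2)\n",
"\n",
"best = None\n",
"for optimize in ['greedy', 'auto']:\n",
"    # build an initial tree, then slice and reconfigure it\n",
"    tree_i = ctg.array_contract_tree(inputs, output, size_dict, optimize=optimize)\n",
"    tree_i = tree_i.slice_and_reconfigure(target_size=2**6)\n",
"    # keep whichever sliced tree is cheapest overall\n",
"    if best is None or tree_i.total_flops() < best.total_flops():\n",
"        best = tree_i\n",
"\n",
"best.total_flops(), best.contraction_width()\n",
"```"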
] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.8" } }, "nbformat": 4, "nbformat_minor": 5 }