{ "cells": [ { "cell_type": "markdown", "id": "551a15b6", "metadata": {}, "source": [ "(ex_extract_contraction)=\n", "\n", "# Extract contraction to matmuls only" ] }, { "cell_type": "code", "execution_count": 6, "id": "7261d7a0-bd84-4357-aeed-c72f0fb5774a", "metadata": {}, "outputs": [], "source": [ "%config InlineBackend.figure_formats = ['svg']\n", "\n", "import autoray as ar\n", "\n", "import cotengra as ctg" ] }, { "cell_type": "markdown", "id": "aecf93ca", "metadata": {}, "source": [ "Create a random contraction and contraction tree:" ] }, { "cell_type": "code", "execution_count": 22, "id": "e1f4edfc-b4f9-4760-87d0-45c335d71175", "metadata": {}, "outputs": [], "source": [ "inputs, output, shapes, size_dict = ctg.utils.rand_equation(\n", " n=6,\n", " reg=5,\n", " n_out=1,\n", " n_hyper_in=1,\n", " n_hyper_out=1,\n", " seed=42,\n", ")\n", "\n", "# square grid contraction:\n", "# inputs, output, shapes, size_dict = ctg.utils.lattice_equation([3, 4])" ] }, { "cell_type": "code", "execution_count": 12, "id": "36c7a326-261f-4e8b-a834-7218824ffda0", "metadata": {}, "outputs": [ { "data": { "image/svg+xml": [ "" ], "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" }, { "data": { "text/plain": [ "(
, )" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "tree = ctg.array_contract_tree(\n", " inputs, output, shapes=shapes, optimize=\"optimal\"\n", ")\n", "tree.get_hypergraph().plot()" ] }, { "cell_type": "markdown", "id": "d8e888d4", "metadata": {}, "source": [ "The high level pairwise contractions can be shown as so:" ] }, { "cell_type": "code", "execution_count": 23, "id": "7f654565-45e6-44fe-8f27-394a45035b98", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[37m(0) cost: \u001b[39m3.2e+01 \u001b[37mwidths: \u001b[39m2.0,5.0->4.0 \u001b[37mtype: \u001b[39meinsum\n", "\u001b[37minputs: \u001b[35m{b}\u001b[31m[i],\u001b[32ma\u001b[35m{b}\u001b[31m[i]\u001b[32mn\u001b[32mo\u001b[39m->\n", "\u001b[37moutput: \u001b[35m{b}\u001b[32m(ano)\n", "\n", "\u001b[37m(1) cost: \u001b[39m6.4e+01 \u001b[37mwidths: \u001b[39m4.0,5.0->5.0 \u001b[37mtype: \u001b[39meinsum\n", "\u001b[37minputs: \u001b[35m{ba}\u001b[31m[n]\u001b[34mo,\u001b[35m{ab}\u001b[32md\u001b[31m[n]\u001b[32mh\u001b[39m->\n", "\u001b[37moutput: \u001b[35m{ba}\u001b[34mo\u001b[32m(dh)\n", "\n", "\u001b[37m(2) cost: \u001b[39m7.7e+02 \u001b[37mwidths: \u001b[39m5.0,8.6->7.6 \u001b[37mtype: \u001b[39meinsum\n", "\u001b[37minputs: \u001b[35m{ba}\u001b[34mo\u001b[31m[dh],\u001b[35m{ab}\u001b[32mc\u001b[31m[d]\u001b[32me\u001b[32mf\u001b[32mg\u001b[31m[h]\u001b[39m->\n", "\u001b[37moutput: \u001b[35m{ba}\u001b[34mo\u001b[32m(cefg)\n", "\n", "\u001b[37m(3) cost: \u001b[39m1.5e+03 \u001b[37mwidths: \u001b[39m7.6,7.0->8.6 \u001b[37mtype: \u001b[39meinsum\n", "\u001b[37minputs: \u001b[35m{ba}\u001b[31m[o]\u001b[34mc\u001b[34me\u001b[34mf\u001b[31m[g],\u001b[35m{ab}\u001b[32mk\u001b[31m[g]\u001b[32ml\u001b[31m[o]\u001b[32mm\u001b[39m->\n", "\u001b[37moutput: \u001b[35m{ba}\u001b[34mc\u001b[34me\u001b[34mf\u001b[32m(klm)\n", "\n", "\u001b[37m(4) cost: \u001b[39m1.2e+03 \u001b[37mwidths: \u001b[39m8.6,9.2->2.6 \u001b[37mtype: 
\u001b[39mtensordot+perm\n", "\u001b[37minputs: \u001b[31m[b]\u001b[34ma\u001b[31m[cefklm],\u001b[31m[b]\u001b[32mj\u001b[31m[cekflm]\u001b[39m->\n", "\u001b[37moutput: \u001b[32m(j)\u001b[34ma\n", "\n", "\u001b[39m\n" ] } ], "source": [ "tree.print_contractions()" ] }, { "cell_type": "markdown", "id": "aa3d8b14", "metadata": {}, "source": [ "## Tracing the computational graph\n", "\n", "`cotengra` can also further break each `einsum` call down into *only* (batch) matrix multiplies, reshapes and transposes. To extract these we need two things:\n", "\n", "1. to trace the contraction with `autoray` lazy arrays.\n", "2. to use `implementation=\"cotengra\"` to avoid using a backend's `einsum` implementation directly." ] }, { "cell_type": "code", "execution_count": 14, "id": "4ae680f6-1a8f-4a48-b36d-8f9a465c5ec7", "metadata": {}, "outputs": [], "source": [ "variables = [ar.lazy.Variable(shape, backend=\"numpy\") for shape in shapes]" ] }, { "cell_type": "code", "execution_count": 15, "id": "df4fea46-854f-4a14-9357-6c352228760b", "metadata": {}, "outputs": [], "source": [ "lz = tree.contract(\n", " variables,\n", " # make cotengra use its own implementation of einsum/tensordot\n", " # which breaks things down to bmm / reshape / transpose\n", " implementation=\"cotengra\",\n", ")" ] }, { "cell_type": "markdown", "id": "660ac484", "metadata": {}, "source": [ "The remaining things are simply `autoray`-related functionality. Visualize the operations:" ] }, { "cell_type": "code", "execution_count": 16, "id": "bd7c9e22", "metadata": {}, "outputs": [ { "data": { "image/svg+xml": [ "" ], "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" }, { "data": { "text/plain": [ "(
, )" ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ "lz.plot_circuit()" ] }, { "cell_type": "markdown", "id": "ad415c45", "metadata": {}, "source": [ "Visualize as text / terminal:" ] }, { "cell_type": "code", "execution_count": 17, "id": "79283c35", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " 0 transpose[3, 2]\n", " 1 ╰─matmul[2, 3]\n", " 2 ├─reshape[192, 3]\n", " 3 │ ╰─transpose[2, 2, 2, 3, 2, 2, 2, 3]\n", " 4 │ ╰─←[2, 3, 2, 2, 2, 3, 2, 2]\n", " 5 ╰─reshape[2, 192]\n", " 6 ╰─transpose[2, 2, 2, 2, 3, 2, 2, 2]\n", " 7 ╰─reshape[2, 2, 2, 2, 3, 2, 2, 2]\n", " 8 ╰─matmul[4, 12, 8]\n", " 9 ├─reshape[4, 4, 8]\n", " 10 │ ╰─transpose[2, 2, 2, 2, 2, 2, 2]\n", " 11 │ ╰─←[2, 2, 2, 2, 2, 2, 2]\n", " 12 ╰─reshape[4, 12, 4]\n", " 13 ╰─transpose[2, 2, 2, 2, 3, 2, 2]\n", " 14 ╰─reshape[2, 2, 2, 2, 2, 3, 2]\n", " 15 ╰─matmul[4, 2, 24]\n", " 16 ├─reshape[4, 4, 24]\n", " 17 │ ╰─transpose[2, 2, 2, 2, 2, 2, 3, 2]\n", " 18 │ ╰─←[2, 2, 2, 2, 2, 3, 2, 2]\n", " 19 ╰─reshape[4, 2, 4]\n", " 20 ╰─matmul[4, 2, 4]\n", " 21 ├─reshape[4, 2, 4]\n", " 22 │ ╰─transpose[2, 2, 2, 2, 2]\n", " 23 │ ╰─←[2, 2, 2, 2, 2]\n", " 24 ╰─reshape[4, 2, 2]\n", " 25 ╰─transpose[2, 2, 2, 2]\n", " 26 ╰─reshape[2, 2, 2, 2]\n", " 27 ╰─matmul[2, 1, 8]\n", " 28 ├─reshape[2, 1, 2]\n", " 29 │ ╰─←[2, 2]\n", " 30 ╰─reshape[2, 2, 8]\n", " 31 ╰─transpose[2, 2, 2, 2, 2]\n", " 32 ╰─←[2, 2, 2, 2, 2]\n" ] } ], "source": [ "lz.show()" ] }, { "cell_type": "markdown", "id": "ad23df09", "metadata": {}, "source": [ "Some preparatory einsums might appear, only for removing initial diagonal indices and trivial indices etc." 
] }, { "cell_type": "markdown", "id": "b75c2f92", "metadata": {}, "source": [ "## getting the linear form of the contraction" ] }, { "cell_type": "code", "execution_count": 18, "id": "ee3860d8", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "['x139996605978592',\n", " 'x139996605418560',\n", " 'x139996605418080',\n", " 'x139996605426720',\n", " 'x139996605418752',\n", " 'x139996605418272']" ] }, "execution_count": 18, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# names of the inputs variables in the contraction\n", "[f\"x{id(v)}\" for v in variables]" ] }, { "cell_type": "code", "execution_count": 19, "id": "035b84ec-ffec-45a6-badb-a6a069d7a7b1", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "x139996605419040 = transpose139997039584176(x139996605418272, (1, 2, 0, 3, 4,))\n", "x139996603351968 = reshape139997039578096(x139996605419040, (2, 2, 8,))\n", "del x139996605419040\n", "x139996605426048 = reshape139997039578096(x139996605418560, (2, 1, 2,))\n", "x139996603351488 = matmul139997611038160(x139996605426048, x139996603351968)\n", "del x139996603351968\n", "del x139996605426048\n", "x139996603351872 = reshape139997039578096(x139996603351488, (2, 2, 2, 2,))\n", "del x139996603351488\n", "x139996603351776 = transpose139997039584176(x139996603351872, (0, 1, 3, 2,))\n", "del x139996603351872\n", "x139996603351680 = reshape139997039578096(x139996603351776, (4, 2, 2,))\n", "del x139996603351776\n", "x139996603351584 = transpose139997039584176(x139996605426720, (1, 0, 3, 2, 4,))\n", "x139996603352160 = reshape139997039578096(x139996603351584, (4, 2, 4,))\n", "del x139996603351584\n", "x139996603352256 = matmul139997611038160(x139996603351680, x139996603352160)\n", "del x139996603352160\n", "del x139996603351680\n", "x139996603352544 = reshape139997039578096(x139996603352256, (4, 2, 4,))\n", "del x139996603352256\n", "x139996603352640 = transpose139997039584176(x139996605978592, (1, 0, 3, 7, 2, 4, 5, 
6,))\n", "x139996603352736 = reshape139997039578096(x139996603352640, (4, 4, 24,))\n", "del x139996603352640\n", "x139996603352832 = matmul139997611038160(x139996603352544, x139996603352736)\n", "del x139996603352736\n", "del x139996603352544\n", "x139996603352928 = reshape139997039578096(x139996603352832, (2, 2, 2, 2, 2, 3, 2,))\n", "del x139996603352832\n", "x139996603353216 = transpose139997039584176(x139996603352928, (0, 1, 3, 4, 5, 2, 6,))\n", "del x139996603352928\n", "x139996603353312 = reshape139997039578096(x139996603353216, (4, 12, 4,))\n", "del x139996603353216\n", "x139996603353504 = transpose139997039584176(x139996605418752, (1, 0, 5, 3, 2, 4, 6,))\n", "x139996603353600 = reshape139997039578096(x139996603353504, (4, 4, 8,))\n", "del x139996603353504\n", "x139996603353696 = matmul139997611038160(x139996603353312, x139996603353600)\n", "del x139996603353600\n", "del x139996603353312\n", "x139996603353792 = reshape139997039578096(x139996603353696, (2, 2, 2, 2, 3, 2, 2, 2,))\n", "del x139996603353696\n", "x139996603354176 = transpose139997039584176(x139996603353792, (1, 0, 2, 3, 4, 5, 6, 7,))\n", "del x139996603353792\n", "x139996603354272 = reshape139997039578096(x139996603354176, (2, 192,))\n", "del x139996603354176\n", "x139996603354368 = transpose139997039584176(x139996605418080, (0, 2, 3, 5, 4, 6, 7, 1,))\n", "x139996603354464 = reshape139997039578096(x139996603354368, (192, 3,))\n", "del x139996603354368\n", "x139996603354560 = matmul139997611038160(x139996603354272, x139996603354464)\n", "del x139996603354464\n", "del x139996603354272\n", "x139996603354656 = transpose139997039584176(x139996603354560, (1, 0,))\n", "del x139996603354560\n" ] } ], "source": [ "# a python source code of the whole contraction\n", "print(lz.get_source())" ] }, { "cell_type": "markdown", "id": "574e460b", "metadata": {}, "source": [ "Or access the nodes programmatically:" ] }, { "cell_type": "code", "execution_count": 20, "id": "192feabf", "metadata": {}, "outputs": [ { 
"name": "stdout", "output_type": "stream", "text": [ "fn: None\n", "args: None\n", "kwargs: None\n", "\n", "fn: \n", "args: (, (1, 2, 0, 3, 4))\n", "kwargs: {}\n", "\n", "fn: \n", "args: (, (2, 2, 8))\n", "kwargs: {}\n", "\n", "fn: None\n", "args: None\n", "kwargs: None\n", "\n", "fn: \n", "args: (, (2, 1, 2))\n", "kwargs: {}\n", "\n", "fn: \n", "args: (, )\n", "kwargs: {}\n", "\n", "fn: \n", "args: (, (2, 2, 2, 2))\n", "kwargs: {}\n", "\n", "fn: \n", "args: (, (0, 1, 3, 2))\n", "kwargs: {}\n", "\n", "fn: \n", "args: (, (4, 2, 2))\n", "kwargs: {}\n", "\n", "fn: None\n", "args: None\n", "kwargs: None\n", "\n", "fn: \n", "args: (, (1, 0, 3, 2, 4))\n", "kwargs: {}\n", "\n", "fn: \n", "args: (, (4, 2, 4))\n", "kwargs: {}\n", "\n", "fn: \n", "args: (, )\n", "kwargs: {}\n", "\n", "fn: \n", "args: (, (4, 2, 4))\n", "kwargs: {}\n", "\n", "fn: None\n", "args: None\n", "kwargs: None\n", "\n", "fn: \n", "args: (, (1, 0, 3, 7, 2, 4, 5, 6))\n", "kwargs: {}\n", "\n", "fn: \n", "args: (, (4, 4, 24))\n", "kwargs: {}\n", "\n", "fn: \n", "args: (, )\n", "kwargs: {}\n", "\n", "fn: \n", "args: (, (2, 2, 2, 2, 2, 3, 2))\n", "kwargs: {}\n", "\n", "fn: \n", "args: (, (0, 1, 3, 4, 5, 2, 6))\n", "kwargs: {}\n", "\n", "fn: \n", "args: (, (4, 12, 4))\n", "kwargs: {}\n", "\n", "fn: None\n", "args: None\n", "kwargs: None\n", "\n", "fn: \n", "args: (, (1, 0, 5, 3, 2, 4, 6))\n", "kwargs: {}\n", "\n", "fn: \n", "args: (, (4, 4, 8))\n", "kwargs: {}\n", "\n", "fn: \n", "args: (, )\n", "kwargs: {}\n", "\n", "fn: \n", "args: (, (2, 2, 2, 2, 3, 2, 2, 2))\n", "kwargs: {}\n", "\n", "fn: \n", "args: (, (1, 0, 2, 3, 4, 5, 6, 7))\n", "kwargs: {}\n", "\n", "fn: \n", "args: (, (2, 192))\n", "kwargs: {}\n", "\n", "fn: None\n", "args: None\n", "kwargs: None\n", "\n", "fn: \n", "args: (, (0, 2, 3, 5, 4, 6, 7, 1))\n", "kwargs: {}\n", "\n", "fn: \n", "args: (, (192, 3))\n", "kwargs: {}\n", "\n", "fn: \n", "args: (, )\n", "kwargs: {}\n", "\n", "fn: \n", "args: (, (1, 0))\n", "kwargs: {}\n", "\n" ] } ], "source": [ 
"for node in ar.lazy.ascend(lz):\n", " print(f\"fn: {node.fn}\")\n", " print(f\"args: {node.args}\")\n", " print(f\"kwargs: {node.kwargs}\")\n", " print()" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3" } }, "nbformat": 4, "nbformat_minor": 5 }