{
"cells": [
{
"cell_type": "markdown",
"id": "551a15b6",
"metadata": {},
"source": [
"(ex_extract_contraction)=\n",
"\n",
"# Extract contraction to matmuls only"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "7261d7a0-bd84-4357-aeed-c72f0fb5774a",
"metadata": {},
"outputs": [],
"source": [
"%config InlineBackend.figure_formats = ['svg']\n",
"\n",
"import autoray as ar\n",
"\n",
"import cotengra as ctg"
]
},
{
"cell_type": "markdown",
"id": "aecf93ca",
"metadata": {},
"source": [
"Create a random contraction and contraction tree:"
]
},
{
"cell_type": "code",
"execution_count": 22,
"id": "e1f4edfc-b4f9-4760-87d0-45c335d71175",
"metadata": {},
"outputs": [],
"source": [
"# keyword options for the random equation, note the fixed seed\n",
"equation_opts = dict(\n",
"    n=6, reg=5, n_out=1, n_hyper_in=1, n_hyper_out=1, seed=42\n",
")\n",
"inputs, output, shapes, size_dict = ctg.utils.rand_equation(**equation_opts)\n",
"\n",
"# alternatively, a square grid contraction:\n",
"# inputs, output, shapes, size_dict = ctg.utils.lattice_equation([3, 4])"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "36c7a326-261f-4e8b-a834-7218824ffda0",
"metadata": {},
"outputs": [
{
"data": {
"image/svg+xml": [
""
],
"text/plain": [
""
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"(, )"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# build the contraction tree using the \"optimal\" optimizer\n",
"tree = ctg.array_contract_tree(\n",
"    inputs,\n",
"    output,\n",
"    shapes=shapes,\n",
"    optimize=\"optimal\",\n",
")\n",
"\n",
"# plot the hypergraph associated with the tree\n",
"hypergraph = tree.get_hypergraph()\n",
"hypergraph.plot()"
]
},
{
"cell_type": "markdown",
"id": "d8e888d4",
"metadata": {},
"source": [
"The high-level pairwise contractions can be displayed like so:"
]
},
{
"cell_type": "code",
"execution_count": 23,
"id": "7f654565-45e6-44fe-8f27-394a45035b98",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[37m(0) cost: \u001b[39m3.2e+01 \u001b[37mwidths: \u001b[39m2.0,5.0->4.0 \u001b[37mtype: \u001b[39meinsum\n",
"\u001b[37minputs: \u001b[35m{b}\u001b[31m[i],\u001b[32ma\u001b[35m{b}\u001b[31m[i]\u001b[32mn\u001b[32mo\u001b[39m->\n",
"\u001b[37moutput: \u001b[35m{b}\u001b[32m(ano)\n",
"\n",
"\u001b[37m(1) cost: \u001b[39m6.4e+01 \u001b[37mwidths: \u001b[39m4.0,5.0->5.0 \u001b[37mtype: \u001b[39meinsum\n",
"\u001b[37minputs: \u001b[35m{ba}\u001b[31m[n]\u001b[34mo,\u001b[35m{ab}\u001b[32md\u001b[31m[n]\u001b[32mh\u001b[39m->\n",
"\u001b[37moutput: \u001b[35m{ba}\u001b[34mo\u001b[32m(dh)\n",
"\n",
"\u001b[37m(2) cost: \u001b[39m7.7e+02 \u001b[37mwidths: \u001b[39m5.0,8.6->7.6 \u001b[37mtype: \u001b[39meinsum\n",
"\u001b[37minputs: \u001b[35m{ba}\u001b[34mo\u001b[31m[dh],\u001b[35m{ab}\u001b[32mc\u001b[31m[d]\u001b[32me\u001b[32mf\u001b[32mg\u001b[31m[h]\u001b[39m->\n",
"\u001b[37moutput: \u001b[35m{ba}\u001b[34mo\u001b[32m(cefg)\n",
"\n",
"\u001b[37m(3) cost: \u001b[39m1.5e+03 \u001b[37mwidths: \u001b[39m7.6,7.0->8.6 \u001b[37mtype: \u001b[39meinsum\n",
"\u001b[37minputs: \u001b[35m{ba}\u001b[31m[o]\u001b[34mc\u001b[34me\u001b[34mf\u001b[31m[g],\u001b[35m{ab}\u001b[32mk\u001b[31m[g]\u001b[32ml\u001b[31m[o]\u001b[32mm\u001b[39m->\n",
"\u001b[37moutput: \u001b[35m{ba}\u001b[34mc\u001b[34me\u001b[34mf\u001b[32m(klm)\n",
"\n",
"\u001b[37m(4) cost: \u001b[39m1.2e+03 \u001b[37mwidths: \u001b[39m8.6,9.2->2.6 \u001b[37mtype: \u001b[39mtensordot+perm\n",
"\u001b[37minputs: \u001b[31m[b]\u001b[34ma\u001b[31m[cefklm],\u001b[31m[b]\u001b[32mj\u001b[31m[cekflm]\u001b[39m->\n",
"\u001b[37moutput: \u001b[32m(j)\u001b[34ma\n",
"\n",
"\u001b[39m\n"
]
}
],
"source": [
"tree.print_contractions()"
]
},
{
"cell_type": "markdown",
"id": "aa3d8b14",
"metadata": {},
"source": [
"## Tracing the computational graph\n",
"\n",
"`cotengra` can also further break each `einsum` call down into *only* (batch) matrix multiplies, reshapes and transposes. To extract these we need two things:\n",
"\n",
"1. to trace the contraction with `autoray` lazy arrays.\n",
"2. to use `implementation=\"cotengra\"` to avoid using a backend's `einsum` implementation directly."
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "4ae680f6-1a8f-4a48-b36d-8f9a465c5ec7",
"metadata": {},
"outputs": [],
"source": [
"# create one autoray lazy variable per input shape to trace with\n",
"variables = []\n",
"for shape in shapes:\n",
"    variables.append(ar.lazy.Variable(shape, backend=\"numpy\"))"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "df4fea46-854f-4a14-9357-6c352228760b",
"metadata": {},
"outputs": [],
"source": [
"# Make cotengra use its own implementation of einsum / tensordot,\n",
"# which breaks each contraction down into bmm / reshape / transpose\n",
"# calls, rather than calling the backend's einsum directly.\n",
"lz = tree.contract(variables, implementation=\"cotengra\")"
]
},
{
"cell_type": "markdown",
"id": "660ac484",
"metadata": {},
"source": [
"The remaining steps simply use `autoray` functionality. Visualize the operations:"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "bd7c9e22",
"metadata": {},
"outputs": [
{
"data": {
"image/svg+xml": [
""
],
"text/plain": [
""
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"(, )"
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"lz.plot_circuit()"
]
},
{
"cell_type": "markdown",
"id": "ad415c45",
"metadata": {},
"source": [
"Visualize as text / terminal:"
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "79283c35",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" 0 transpose[3, 2]\n",
" 1 ╰─matmul[2, 3]\n",
" 2 ├─reshape[192, 3]\n",
" 3 │ ╰─transpose[2, 2, 2, 3, 2, 2, 2, 3]\n",
" 4 │ ╰─←[2, 3, 2, 2, 2, 3, 2, 2]\n",
" 5 ╰─reshape[2, 192]\n",
" 6 ╰─transpose[2, 2, 2, 2, 3, 2, 2, 2]\n",
" 7 ╰─reshape[2, 2, 2, 2, 3, 2, 2, 2]\n",
" 8 ╰─matmul[4, 12, 8]\n",
" 9 ├─reshape[4, 4, 8]\n",
" 10 │ ╰─transpose[2, 2, 2, 2, 2, 2, 2]\n",
" 11 │ ╰─←[2, 2, 2, 2, 2, 2, 2]\n",
" 12 ╰─reshape[4, 12, 4]\n",
" 13 ╰─transpose[2, 2, 2, 2, 3, 2, 2]\n",
" 14 ╰─reshape[2, 2, 2, 2, 2, 3, 2]\n",
" 15 ╰─matmul[4, 2, 24]\n",
" 16 ├─reshape[4, 4, 24]\n",
" 17 │ ╰─transpose[2, 2, 2, 2, 2, 2, 3, 2]\n",
" 18 │ ╰─←[2, 2, 2, 2, 2, 3, 2, 2]\n",
" 19 ╰─reshape[4, 2, 4]\n",
" 20 ╰─matmul[4, 2, 4]\n",
" 21 ├─reshape[4, 2, 4]\n",
" 22 │ ╰─transpose[2, 2, 2, 2, 2]\n",
" 23 │ ╰─←[2, 2, 2, 2, 2]\n",
" 24 ╰─reshape[4, 2, 2]\n",
" 25 ╰─transpose[2, 2, 2, 2]\n",
" 26 ╰─reshape[2, 2, 2, 2]\n",
" 27 ╰─matmul[2, 1, 8]\n",
" 28 ├─reshape[2, 1, 2]\n",
" 29 │ ╰─←[2, 2]\n",
" 30 ╰─reshape[2, 2, 8]\n",
" 31 ╰─transpose[2, 2, 2, 2, 2]\n",
" 32 ╰─←[2, 2, 2, 2, 2]\n"
]
}
],
"source": [
"lz.show()"
]
},
{
"cell_type": "markdown",
"id": "ad23df09",
"metadata": {},
"source": [
"Some preparatory `einsum` calls might still appear, but only to remove initial diagonal indices, trivial indices, etc."
]
},
{
"cell_type": "markdown",
"id": "b75c2f92",
"metadata": {},
"source": [
"## Getting the linear form of the contraction"
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "ee3860d8",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"['x139996605978592',\n",
" 'x139996605418560',\n",
" 'x139996605418080',\n",
" 'x139996605426720',\n",
" 'x139996605418752',\n",
" 'x139996605418272']"
]
},
"execution_count": 18,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# names of the input variables in the contraction\n",
"[f\"x{id(v)}\" for v in variables]"
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "035b84ec-ffec-45a6-badb-a6a069d7a7b1",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"x139996605419040 = transpose139997039584176(x139996605418272, (1, 2, 0, 3, 4,))\n",
"x139996603351968 = reshape139997039578096(x139996605419040, (2, 2, 8,))\n",
"del x139996605419040\n",
"x139996605426048 = reshape139997039578096(x139996605418560, (2, 1, 2,))\n",
"x139996603351488 = matmul139997611038160(x139996605426048, x139996603351968)\n",
"del x139996603351968\n",
"del x139996605426048\n",
"x139996603351872 = reshape139997039578096(x139996603351488, (2, 2, 2, 2,))\n",
"del x139996603351488\n",
"x139996603351776 = transpose139997039584176(x139996603351872, (0, 1, 3, 2,))\n",
"del x139996603351872\n",
"x139996603351680 = reshape139997039578096(x139996603351776, (4, 2, 2,))\n",
"del x139996603351776\n",
"x139996603351584 = transpose139997039584176(x139996605426720, (1, 0, 3, 2, 4,))\n",
"x139996603352160 = reshape139997039578096(x139996603351584, (4, 2, 4,))\n",
"del x139996603351584\n",
"x139996603352256 = matmul139997611038160(x139996603351680, x139996603352160)\n",
"del x139996603352160\n",
"del x139996603351680\n",
"x139996603352544 = reshape139997039578096(x139996603352256, (4, 2, 4,))\n",
"del x139996603352256\n",
"x139996603352640 = transpose139997039584176(x139996605978592, (1, 0, 3, 7, 2, 4, 5, 6,))\n",
"x139996603352736 = reshape139997039578096(x139996603352640, (4, 4, 24,))\n",
"del x139996603352640\n",
"x139996603352832 = matmul139997611038160(x139996603352544, x139996603352736)\n",
"del x139996603352736\n",
"del x139996603352544\n",
"x139996603352928 = reshape139997039578096(x139996603352832, (2, 2, 2, 2, 2, 3, 2,))\n",
"del x139996603352832\n",
"x139996603353216 = transpose139997039584176(x139996603352928, (0, 1, 3, 4, 5, 2, 6,))\n",
"del x139996603352928\n",
"x139996603353312 = reshape139997039578096(x139996603353216, (4, 12, 4,))\n",
"del x139996603353216\n",
"x139996603353504 = transpose139997039584176(x139996605418752, (1, 0, 5, 3, 2, 4, 6,))\n",
"x139996603353600 = reshape139997039578096(x139996603353504, (4, 4, 8,))\n",
"del x139996603353504\n",
"x139996603353696 = matmul139997611038160(x139996603353312, x139996603353600)\n",
"del x139996603353600\n",
"del x139996603353312\n",
"x139996603353792 = reshape139997039578096(x139996603353696, (2, 2, 2, 2, 3, 2, 2, 2,))\n",
"del x139996603353696\n",
"x139996603354176 = transpose139997039584176(x139996603353792, (1, 0, 2, 3, 4, 5, 6, 7,))\n",
"del x139996603353792\n",
"x139996603354272 = reshape139997039578096(x139996603354176, (2, 192,))\n",
"del x139996603354176\n",
"x139996603354368 = transpose139997039584176(x139996605418080, (0, 2, 3, 5, 4, 6, 7, 1,))\n",
"x139996603354464 = reshape139997039578096(x139996603354368, (192, 3,))\n",
"del x139996603354368\n",
"x139996603354560 = matmul139997611038160(x139996603354272, x139996603354464)\n",
"del x139996603354464\n",
"del x139996603354272\n",
"x139996603354656 = transpose139997039584176(x139996603354560, (1, 0,))\n",
"del x139996603354560\n"
]
}
],
"source": [
"# Python source code for the whole contraction\n",
"print(lz.get_source())"
]
},
{
"cell_type": "markdown",
"id": "574e460b",
"metadata": {},
"source": [
"Or access the nodes programmatically:"
]
},
{
"cell_type": "code",
"execution_count": 20,
"id": "192feabf",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"fn: None\n",
"args: None\n",
"kwargs: None\n",
"\n",
"fn: \n",
"args: (, (1, 2, 0, 3, 4))\n",
"kwargs: {}\n",
"\n",
"fn: \n",
"args: (, (2, 2, 8))\n",
"kwargs: {}\n",
"\n",
"fn: None\n",
"args: None\n",
"kwargs: None\n",
"\n",
"fn: \n",
"args: (, (2, 1, 2))\n",
"kwargs: {}\n",
"\n",
"fn: \n",
"args: (, )\n",
"kwargs: {}\n",
"\n",
"fn: \n",
"args: (, (2, 2, 2, 2))\n",
"kwargs: {}\n",
"\n",
"fn: \n",
"args: (, (0, 1, 3, 2))\n",
"kwargs: {}\n",
"\n",
"fn: \n",
"args: (, (4, 2, 2))\n",
"kwargs: {}\n",
"\n",
"fn: None\n",
"args: None\n",
"kwargs: None\n",
"\n",
"fn: \n",
"args: (, (1, 0, 3, 2, 4))\n",
"kwargs: {}\n",
"\n",
"fn: \n",
"args: (, (4, 2, 4))\n",
"kwargs: {}\n",
"\n",
"fn: \n",
"args: (, )\n",
"kwargs: {}\n",
"\n",
"fn: \n",
"args: (, (4, 2, 4))\n",
"kwargs: {}\n",
"\n",
"fn: None\n",
"args: None\n",
"kwargs: None\n",
"\n",
"fn: \n",
"args: (, (1, 0, 3, 7, 2, 4, 5, 6))\n",
"kwargs: {}\n",
"\n",
"fn: \n",
"args: (, (4, 4, 24))\n",
"kwargs: {}\n",
"\n",
"fn: \n",
"args: (, )\n",
"kwargs: {}\n",
"\n",
"fn: \n",
"args: (, (2, 2, 2, 2, 2, 3, 2))\n",
"kwargs: {}\n",
"\n",
"fn: \n",
"args: (, (0, 1, 3, 4, 5, 2, 6))\n",
"kwargs: {}\n",
"\n",
"fn: \n",
"args: (, (4, 12, 4))\n",
"kwargs: {}\n",
"\n",
"fn: None\n",
"args: None\n",
"kwargs: None\n",
"\n",
"fn: \n",
"args: (, (1, 0, 5, 3, 2, 4, 6))\n",
"kwargs: {}\n",
"\n",
"fn: \n",
"args: (, (4, 4, 8))\n",
"kwargs: {}\n",
"\n",
"fn: \n",
"args: (, )\n",
"kwargs: {}\n",
"\n",
"fn: \n",
"args: (, (2, 2, 2, 2, 3, 2, 2, 2))\n",
"kwargs: {}\n",
"\n",
"fn: \n",
"args: (, (1, 0, 2, 3, 4, 5, 6, 7))\n",
"kwargs: {}\n",
"\n",
"fn: \n",
"args: (, (2, 192))\n",
"kwargs: {}\n",
"\n",
"fn: None\n",
"args: None\n",
"kwargs: None\n",
"\n",
"fn: \n",
"args: (, (0, 2, 3, 5, 4, 6, 7, 1))\n",
"kwargs: {}\n",
"\n",
"fn: \n",
"args: (, (192, 3))\n",
"kwargs: {}\n",
"\n",
"fn: \n",
"args: (, )\n",
"kwargs: {}\n",
"\n",
"fn: \n",
"args: (, (1, 0))\n",
"kwargs: {}\n",
"\n"
]
}
],
"source": [
"# report the function, args and kwargs of each node yielded by\n",
"# ar.lazy.ascend, with a blank line separating consecutive nodes\n",
"for node in ar.lazy.ascend(lz):\n",
"    node_report = (\n",
"        f\"fn: {node.fn}\",\n",
"        f\"args: {node.args}\",\n",
"        f\"kwargs: {node.kwargs}\",\n",
"        \"\",\n",
"    )\n",
"    print(*node_report, sep=\"\\n\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3"
}
},
"nbformat": 4,
"nbformat_minor": 5
}