{
"cells": [
{
"cell_type": "markdown",
"id": "e07b8472-d0ed-4d9f-b35d-9d3d515e9823",
"metadata": {},
"source": [
"# Visualization\n",
"\n",
"Various aspects of the contraction process can be visualized using the following functions."
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "11a44ba1-85c4-4194-a6e7-287393f7f4c3",
"metadata": {},
"outputs": [],
"source": [
"%config InlineBackend.figure_formats = ['svg']\n",
"import cotengra as ctg"
]
},
{
"cell_type": "markdown",
"id": "630b9be5",
"metadata": {},
"source": [
"## Hypergraph visualization\n",
"\n",
"The first thing we can visualize is the hypergraph corresponding to the geometry of the tensor network or equation, using [`HyperGraph.plot`](cotengra.plot.plot_hypergraph). Hyper edges are shown as a zero-size vertex connecting the tensors they appear on, like a COPY-tensor:"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "af822cae",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"abce,bdf,bdef,a,b,ab->ca\n"
]
},
{
"data": {
"image/svg+xml": [
""
],
"text/plain": [
""
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"inputs, output, shapes, size_dict = ctg.utils.rand_equation(\n",
" 6,\n",
" 2,\n",
" n_out=1,\n",
" n_hyper_in=1,\n",
" n_hyper_out=1,\n",
" seed=4,\n",
")\n",
"print(ctg.utils.inputs_output_to_eq(inputs, output))\n",
"ctg.get_hypergraph(inputs, output, size_dict).plot();"
]
},
{
"cell_type": "markdown",
"id": "510c3048",
"metadata": {},
"source": [
"This simple contraction has five types of index:\n",
"\n",
"1. standard inner index - 'e' - which appears on exactly two tensors\n",
"2. standard inner multi-indices - 'd', 'f' - which both appear on the same two tensors\n",
"3. standard outer index - 'c' - which appears on exactly one tensor and the output\n",
"4. hyper inner index - 'b' - which appears on more than two tensors\n",
"5. hyper outer index - 'a' - which appears on multiple tensors and the output\n",
"\n",
"The nodes and indices are assigned unique colors by default, with hyper indices\n",
"shown as dashed lines."
]
},
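{
"cell_type": "markdown",
"id": "5d3a1e7b",
"metadata": {},
"source": [
"The five types above can be told apart just by counting how many input tensors each index appears on and checking whether it also appears in the output. The following is a minimal sketch of that rule in plain Python, assuming equations given as lists of index strings. The helpers `classify_indices` and `multi_indices` are hypothetical and not part of the `cotengra` API:\n",
"\n",
"```python\n",
"from collections import Counter\n",
"\n",
"\n",
"def classify_indices(inputs, output):\n",
"    # count the number of input tensors each index appears on\n",
"    counts = Counter(ix for term in inputs for ix in set(term))\n",
"    kinds = {}\n",
"    for ix, n in counts.items():\n",
"        if ix in output:\n",
"            kinds[ix] = 'standard outer' if n == 1 else 'hyper outer'\n",
"        elif n == 1:\n",
"            # not one of the five types: summed over a single tensor\n",
"            kinds[ix] = 'single-tensor inner'\n",
"        elif n == 2:\n",
"            kinds[ix] = 'standard inner'\n",
"        else:\n",
"            kinds[ix] = 'hyper inner'\n",
"    return kinds\n",
"\n",
"\n",
"def multi_indices(inputs, output):\n",
"    # group standard inner indices by the pair of tensors they connect\n",
"    groups = {}\n",
"    for ix in set(''.join(inputs)) - set(output):\n",
"        where = tuple(i for i, term in enumerate(inputs) if ix in term)\n",
"        if len(where) == 2:\n",
"            groups.setdefault(where, []).append(ix)\n",
"    # indices sharing the same pair of tensors form a multi-index\n",
"    return [sorted(g) for g in groups.values() if len(g) > 1]\n",
"\n",
"\n",
"# the equation printed above: abce,bdf,bdef,a,b,ab->ca\n",
"inputs = ['abce', 'bdf', 'bdef', 'a', 'b', 'ab']\n",
"kinds = classify_indices(inputs, 'ca')\n",
"print(dict(sorted(kinds.items())))\n",
"print(multi_indices(inputs, 'ca'))\n",
"```"
]
},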
{
"cell_type": "markdown",
"id": "2f49334a",
"metadata": {},
"source": [
"### Small tree visualization"
]
},
{
"cell_type": "markdown",
"id": "634f6ed1",
"metadata": {},
"source": [
"If the network is small enough and we have a [`ContractionTree`](cotengra.core.ContractionTree) for it, we can also visualize it in its entirety, including all indices involved in each intermediate contraction, using the [`ContractionTree.plot_flat`](cotengra.plot.plot_tree_flat) method:"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "43957aeb",
"metadata": {},
"outputs": [
{
"data": {
"image/svg+xml": [
""
],
"text/plain": [
""
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"tree = ctg.array_contract_tree(inputs, output, size_dict)\n",
"tree.plot_flat();"
]
},
{
"cell_type": "markdown",
"id": "c9056c49",
"metadata": {},
"source": [
"Here the default unique node and index coloring matches that of the default hypergraph visualization.\n",
"The contraction flows from bottom to top."
]
},
{
"cell_type": "markdown",
"id": "4745beb1",
"metadata": {},
"source": [
"For the remaining examples, we'll generate a larger 2D lattice contraction, with no outputs or hyper indices:"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "18580c42",
"metadata": {},
"outputs": [
{
"data": {
"image/svg+xml": [
""
],
"text/plain": [
""
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# generate an equation representing a 2D lattice\n",
"inputs, output, shapes, size_dict = ctg.utils.lattice_equation([5, 6])\n",
"hg = ctg.get_hypergraph(inputs, output, size_dict)\n",
"hg.plot(draw_edge_labels=True);"
]
},
{
"cell_type": "markdown",
"id": "1c92b55c",
"metadata": {},
"source": [
"You can turn off the node/edge coloring or set the node coloring to a simple centrality measure like so:"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "05b28e38",
"metadata": {},
"outputs": [
{
"data": {
"image/svg+xml": [
""
],
"text/plain": [
""
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"hg.plot(node_color=\"centrality\", edge_color=False);"
]
},
{
"cell_type": "markdown",
"id": "7daa6358",
"metadata": {},
"source": [
"## Optimizer trials visualization\n",
"\n",
"If we run a hyper optimizer, we can visualize how the scores progress with trials using the\n",
"[`HyperOptimizer.plot_trials`](cotengra.plot.plot_trials) method:"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "7b44f879",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"F=3.56 C=4.82 S=6.00 P=8.51: 100%|██████████| 128/128 [00:03<00:00, 33.01it/s]\n"
]
}
],
"source": [
"opt = ctg.HyperOptimizer(methods=[\"greedy\", \"kahypar\", \"labels\"], progbar=True)\n",
"\n",
"# run it and generate a tree\n",
"tree = opt.search(inputs, output, size_dict)"
]
},
{
"cell_type": "markdown",
"id": "cf738b6d",
"metadata": {},
"source": [
"By default the y-axis is the objective score, but you can specify e.g. `'flops'`:"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "3e278770-a82d-4178-a57a-542e67b2eecf",
"metadata": {},
"outputs": [
{
"data": {
"image/svg+xml": [
""
],
"text/plain": [
""
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"opt.plot_trials(y=\"flops\");"
]
},
{
"cell_type": "markdown",
"id": "ab124c34",
"metadata": {},
"source": [
"Similarly, you can supply `x=\"time\"` to plot the scores as a function of cumulative CPU time.\n",
"\n",
"We can also plot the distribution of contraction costs against contraction widths using\n",
"the [`HyperOptimizer.plot_scatter`](cotengra.plot.plot_scatter) method:"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "28974b9a",
"metadata": {},
"outputs": [
{
"data": {
"image/svg+xml": [
""
],
"text/plain": [
""
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"opt.plot_scatter(x=\"size\", y=\"flops\");"
]
},
{
"cell_type": "markdown",
"id": "9f453a6c",
"metadata": {},
"source": [
"You can examine the actual distribution of parameters chosen for each method with `HyperOptimizer.plot_parameters_parallel`:"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "8e3fab65",
"metadata": {},
"outputs": [
{
"data": {
"image/svg+xml": [
""
],
"text/plain": [
""
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"opt.plot_parameters_parallel(\"greedy\");"
]
},
{
"cell_type": "markdown",
"id": "c3f270e2",
"metadata": {},
"source": [
"## Large tree visualizations\n",
"\n",
"The following visualization functions are available for inspecting a single,\n",
"complete [`ContractionTree`](cotengra.core.ContractionTree) once generated.\n",
"They mostly wrap [`plot_tree`](cotengra.plot.plot_tree), whose documentation\n",
"lists most of the extra options."
]
},
{
"cell_type": "markdown",
"id": "ddfe5672",
"metadata": {},
"source": [
"### Contractions\n",
"\n",
"[`tree.plot_contractions`](cotengra.plot.plot_contractions)\n",
"gives you an overview of the memory and costs throughout the contraction:"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "27532d10",
"metadata": {},
"outputs": [
{
"data": {
"image/svg+xml": [
""
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"tree.plot_contractions();"
]
},
{
"cell_type": "markdown",
"id": "8179598e",
"metadata": {},
"source": [
"Here, `peak` is the memory required to store all current intermediates at once,\n",
"`write` is the size of the new intermediate tensor, the max of which is the\n",
"*contraction width*, and `cost` is the number of scalar operations of each contraction.\n",
"\n",
"The list of corresponding pairwise contractions can be explicitly shown with the [`print_contractions`](cotengra.core.ContractionTree.print_contractions) method:"
]
},
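{
"cell_type": "markdown",
"id": "9b6c2f40",
"metadata": {},
"source": [
"In the `print_contractions` listing, each `widths` entry is the log2 of a tensor's size. Since this lattice equation uses dimension-2 indices throughout (e.g. the three-index input `ihj` of step (0) has width 3.0), the reported `cost` of each step equals the product of the sizes of every index appearing on either input. A minimal sketch of that arithmetic, assuming a `size_dict` mapping indices to dimensions and Python 3.8+ for `math.prod`. The helper `pairwise_cost` is hypothetical, not a `cotengra` function:\n",
"\n",
"```python\n",
"import math\n",
"\n",
"\n",
"def pairwise_cost(left, right, keep, size_dict):\n",
"    # every index on either input is looped over once\n",
"    involved = set(left) | set(right)\n",
"    cost = math.prod(size_dict[ix] for ix in involved)\n",
"    # the new intermediate only carries indices that are still needed\n",
"    write = math.prod(size_dict[ix] for ix in involved if ix in keep)\n",
"    # widths are reported as log2 of the tensor size\n",
"    return cost, math.log2(write)\n",
"\n",
"\n",
"# step (0): ihj,kj -> ihk with every index of dimension 2\n",
"size_dict = {ix: 2 for ix in 'ihjk'}\n",
"cost, width = pairwise_cost('ihj', 'kj', keep='ihk', size_dict=size_dict)\n",
"print(cost, width)\n",
"```"
]
},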
{
"cell_type": "code",
"execution_count": 10,
"id": "1297354c",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[37m(0) cost: \u001b[39m1.6e+01 \u001b[37mwidths: \u001b[39m3.0,2.0->3.0 \u001b[37mtype: \u001b[39mtensordot\n",
"\u001b[37minputs: \u001b[34mi\u001b[34mh\u001b[31m[j],\u001b[32mk\u001b[31m[j]\u001b[39m->\n",
"\u001b[37moutput: \u001b[34mi\u001b[34mh\u001b[32m(k)\n",
"\n",
"\u001b[37m(1) cost: \u001b[39m3.2e+01 \u001b[37mwidths: \u001b[39m3.0,3.0->4.0 \u001b[37mtype: \u001b[39mtensordot\n",
"\u001b[37minputs: \u001b[34mi\u001b[34mh\u001b[31m[k],\u001b[31m[k]\u001b[32mv\u001b[32mu\u001b[39m->\n",
"\u001b[37moutput: \u001b[34mi\u001b[34mh\u001b[32m(vu)\n",
"\n",
"\u001b[37m(2) cost: \u001b[39m6.4e+01 \u001b[37mwidths: \u001b[39m4.0,4.0->4.0 \u001b[37mtype: \u001b[39mtensordot\n",
"\u001b[37minputs: \u001b[31m[i]\u001b[34mh\u001b[34mv\u001b[31m[u],\u001b[31m[i]\u001b[32mt\u001b[32ms\u001b[31m[u]\u001b[39m->\n",
"\u001b[37moutput: \u001b[34mh\u001b[34mv\u001b[32m(ts)\n",
"\n",
"\u001b[37m(3) cost: \u001b[39m6.4e+01 \u001b[37mwidths: \u001b[39m4.0,3.0->5.0 \u001b[37mtype: \u001b[39mtensordot\n",
"\u001b[37minputs: \u001b[34mh\u001b[31m[v]\u001b[34mt\u001b[34ms,\u001b[31m[v]\u001b[32mG\u001b[32mF\u001b[39m->\n",
"\u001b[37moutput: \u001b[34mh\u001b[34mt\u001b[34ms\u001b[32m(GF)\n",
"\n",
"\u001b[37m(4) cost: \u001b[39m1.3e+02 \u001b[37mwidths: \u001b[39m5.0,4.0->5.0 \u001b[37mtype: \u001b[39mtensordot\n",
"\u001b[37minputs: \u001b[34mh\u001b[31m[t]\u001b[34ms\u001b[34mG\u001b[31m[F],\u001b[31m[t]\u001b[32mE\u001b[32mD\u001b[31m[F]\u001b[39m->\n",
"\u001b[37moutput: \u001b[34mh\u001b[34ms\u001b[34mG\u001b[32m(ED)\n",
"\n",
"\u001b[37m(5) cost: \u001b[39m1.6e+01 \u001b[37mwidths: \u001b[39m3.0,2.0->3.0 \u001b[37mtype: \u001b[39mtensordot\n",
"\u001b[37minputs: \u001b[34mP\u001b[34mV\u001b[31m[W],\u001b[32mR\u001b[31m[W]\u001b[39m->\n",
"\u001b[37moutput: \u001b[34mP\u001b[34mV\u001b[32m(R)\n",
"\n",
"\u001b[37m(6) cost: \u001b[39m3.2e+01 \u001b[37mwidths: \u001b[39m3.0,3.0->4.0 \u001b[37mtype: \u001b[39mtensordot\n",
"\u001b[37minputs: \u001b[34mP\u001b[34mV\u001b[31m[R],\u001b[32mG\u001b[31m[R]\u001b[32mQ\u001b[39m->\n",
"\u001b[37moutput: \u001b[34mP\u001b[34mV\u001b[32m(GQ)\n",
"\n",
"\u001b[37m(7) cost: \u001b[39m6.4e+01 \u001b[37mwidths: \u001b[39m4.0,4.0->4.0 \u001b[37mtype: \u001b[39mtensordot\n",
"\u001b[37minputs: \u001b[31m[P]\u001b[34mV\u001b[34mG\u001b[31m[Q],\u001b[32mE\u001b[31m[P]\u001b[32mO\u001b[31m[Q]\u001b[39m->\n",
"\u001b[37moutput: \u001b[34mV\u001b[34mG\u001b[32m(EO)\n",
"\n",
"\u001b[37m(8) cost: \u001b[39m1.3e+02 \u001b[37mwidths: \u001b[39m5.0,4.0->5.0 \u001b[37mtype: \u001b[39mtensordot\n",
"\u001b[37minputs: \u001b[34mh\u001b[34ms\u001b[31m[GE]\u001b[34mD,\u001b[32mV\u001b[31m[GE]\u001b[32mO\u001b[39m->\n",
"\u001b[37moutput: \u001b[34mh\u001b[34ms\u001b[34mD\u001b[32m(VO)\n",
"\n",
"\u001b[37m(9) cost: \u001b[39m6.4e+01 \u001b[37mwidths: \u001b[39m4.0,3.0->5.0 \u001b[37mtype: \u001b[39mtensordot\n",
"\u001b[37minputs: \u001b[34mC\u001b[31m[N]\u001b[34mM\u001b[34mO,\u001b[31m[N]\u001b[32mU\u001b[32mV\u001b[39m->\n",
"\u001b[37moutput: \u001b[34mC\u001b[34mM\u001b[34mO\u001b[32m(UV)\n",
"\n",
"\u001b[37m(10) cost: \u001b[39m2.6e+02 \u001b[37mwidths: \u001b[39m5.0,5.0->6.0 \u001b[37mtype: \u001b[39mtensordot\n",
"\u001b[37minputs: \u001b[34mh\u001b[34ms\u001b[34mD\u001b[31m[VO],\u001b[32mC\u001b[32mM\u001b[31m[O]\u001b[32mU\u001b[31m[V]\u001b[39m->\n",
"\u001b[37moutput: \u001b[34mh\u001b[34ms\u001b[34mD\u001b[32m(CMU)\n",
"\n",
"\u001b[37m(11) cost: \u001b[39m2.6e+02 \u001b[37mwidths: \u001b[39m6.0,4.0->6.0 \u001b[37mtype: \u001b[39mtensordot\n",
"\u001b[37minputs: \u001b[34mh\u001b[34ms\u001b[31m[DC]\u001b[34mM\u001b[34mU,\u001b[32mr\u001b[31m[C]\u001b[32mB\u001b[31m[D]\u001b[39m->\n",
"\u001b[37moutput: \u001b[34mh\u001b[34ms\u001b[34mM\u001b[34mU\u001b[32m(rB)\n",
"\n",
"\u001b[37m(12) cost: \u001b[39m6.4e+01 \u001b[37mwidths: \u001b[39m3.0,4.0->5.0 \u001b[37mtype: \u001b[39mtensordot\n",
"\u001b[37minputs: \u001b[31m[g]\u001b[34mf\u001b[34mh,\u001b[31m[g]\u001b[32mr\u001b[32mq\u001b[32ms\u001b[39m->\n",
"\u001b[37moutput: \u001b[34mf\u001b[34mh\u001b[32m(rqs)\n",
"\n",
"\u001b[37m(13) cost: \u001b[39m2.6e+02 \u001b[37mwidths: \u001b[39m6.0,5.0->5.0 \u001b[37mtype: \u001b[39mtensordot\n",
"\u001b[37minputs: \u001b[31m[hs]\u001b[34mM\u001b[34mU\u001b[31m[r]\u001b[34mB,\u001b[32mf\u001b[31m[hr]\u001b[32mq\u001b[31m[s]\u001b[39m->\n",
"\u001b[37moutput: \u001b[34mM\u001b[34mU\u001b[34mB\u001b[32m(fq)\n",
"\n",
"\u001b[37m(14) cost: \u001b[39m1.6e+01 \u001b[37mwidths: \u001b[39m3.0,2.0->3.0 \u001b[37mtype: \u001b[39mtensordot\n",
"\u001b[37minputs: \u001b[34mw\u001b[31m[H]\u001b[34mI,\u001b[31m[H]\u001b[32mS\u001b[39m->\n",
"\u001b[37moutput: \u001b[34mw\u001b[34mI\u001b[32m(S)\n",
"\n",
"\u001b[37m(15) cost: \u001b[39m3.2e+01 \u001b[37mwidths: \u001b[39m3.0,3.0->4.0 \u001b[37mtype: \u001b[39mtensordot\n",
"\u001b[37minputs: \u001b[34mw\u001b[34mI\u001b[31m[S],\u001b[32mJ\u001b[31m[S]\u001b[32mT\u001b[39m->\n",
"\u001b[37moutput: \u001b[34mw\u001b[34mI\u001b[32m(JT)\n",
"\n",
"\u001b[37m(16) cost: \u001b[39m6.4e+01 \u001b[37mwidths: \u001b[39m4.0,4.0->4.0 \u001b[37mtype: \u001b[39mtensordot\n",
"\u001b[37minputs: \u001b[34mw\u001b[31m[IJ]\u001b[34mT,\u001b[32my\u001b[31m[JI]\u001b[32mK\u001b[39m->\n",
"\u001b[37moutput: \u001b[34mw\u001b[34mT\u001b[32m(yK)\n",
"\n",
"\u001b[37m(17) cost: \u001b[39m6.4e+01 \u001b[37mwidths: \u001b[39m4.0,3.0->5.0 \u001b[37mtype: \u001b[39mtensordot\n",
"\u001b[37minputs: \u001b[31m[w]\u001b[34mT\u001b[34my\u001b[34mK,\u001b[32ml\u001b[31m[w]\u001b[32mx\u001b[39m->\n",
"\u001b[37moutput: \u001b[34mT\u001b[34my\u001b[34mK\u001b[32m(lx)\n",
"\n",
"\u001b[37m(18) cost: \u001b[39m1.3e+02 \u001b[37mwidths: \u001b[39m5.0,4.0->5.0 \u001b[37mtype: \u001b[39mtensordot\n",
"\u001b[37minputs: \u001b[34mT\u001b[31m[y]\u001b[34mK\u001b[34ml\u001b[31m[x],\u001b[32mn\u001b[31m[yx]\u001b[32mz\u001b[39m->\n",
"\u001b[37moutput: \u001b[34mT\u001b[34mK\u001b[34ml\u001b[32m(nz)\n",
"\n",
"\u001b[37m(19) cost: \u001b[39m6.4e+01 \u001b[37mwidths: \u001b[39m4.0,3.0->5.0 \u001b[37mtype: \u001b[39mtensordot\n",
"\u001b[37minputs: \u001b[34mA\u001b[31m[L]\u001b[34mK\u001b[34mM,\u001b[31m[L]\u001b[32mT\u001b[32mU\u001b[39m->\n",
"\u001b[37moutput: \u001b[34mA\u001b[34mK\u001b[34mM\u001b[32m(TU)\n",
"\n",
"\u001b[37m(20) cost: \u001b[39m2.6e+02 \u001b[37mwidths: \u001b[39m5.0,5.0->6.0 \u001b[37mtype: \u001b[39mtensordot\n",
"\u001b[37minputs: \u001b[31m[TK]\u001b[34ml\u001b[34mn\u001b[34mz,\u001b[32mA\u001b[31m[K]\u001b[32mM\u001b[31m[T]\u001b[32mU\u001b[39m->\n",
"\u001b[37moutput: \u001b[34ml\u001b[34mn\u001b[34mz\u001b[32m(AMU)\n",
"\n",
"\u001b[37m(21) cost: \u001b[39m2.6e+02 \u001b[37mwidths: \u001b[39m6.0,4.0->6.0 \u001b[37mtype: \u001b[39mtensordot\n",
"\u001b[37minputs: \u001b[34ml\u001b[34mn\u001b[31m[zA]\u001b[34mM\u001b[34mU,\u001b[32mp\u001b[31m[Az]\u001b[32mB\u001b[39m->\n",
"\u001b[37moutput: \u001b[34ml\u001b[34mn\u001b[34mM\u001b[34mU\u001b[32m(pB)\n",
"\n",
"\u001b[37m(22) cost: \u001b[39m2.6e+02 \u001b[37mwidths: \u001b[39m5.0,6.0->5.0 \u001b[37mtype: \u001b[39mtensordot\n",
"\u001b[37minputs: \u001b[31m[MUB]\u001b[34mf\u001b[34mq,\u001b[32ml\u001b[32mn\u001b[31m[MU]\u001b[32mp\u001b[31m[B]\u001b[39m->\n",
"\u001b[37moutput: \u001b[34mf\u001b[34mq\u001b[32m(lnp)\n",
"\n",
"\u001b[37m(23) cost: \u001b[39m1.3e+02 \u001b[37mwidths: \u001b[39m5.0,4.0->5.0 \u001b[37mtype: \u001b[39mtensordot\n",
"\u001b[37minputs: \u001b[34mf\u001b[31m[q]\u001b[34ml\u001b[34mn\u001b[31m[p],\u001b[32me\u001b[31m[p]\u001b[32mo\u001b[31m[q]\u001b[39m->\n",
"\u001b[37moutput: \u001b[34mf\u001b[34ml\u001b[34mn\u001b[32m(eo)\n",
"\n",
"\u001b[37m(24) cost: \u001b[39m6.4e+01 \u001b[37mwidths: \u001b[39m5.0,3.0->4.0 \u001b[37mtype: \u001b[39mtensordot\n",
"\u001b[37minputs: \u001b[31m[f]\u001b[34ml\u001b[34mn\u001b[31m[e]\u001b[34mo,\u001b[31m[e]\u001b[32md\u001b[31m[f]\u001b[39m->\n",
"\u001b[37moutput: \u001b[34ml\u001b[34mn\u001b[34mo\u001b[32m(d)\n",
"\n",
"\u001b[37m(25) cost: \u001b[39m6.4e+01 \u001b[37mwidths: \u001b[39m4.0,4.0->4.0 \u001b[37mtype: \u001b[39mtensordot\n",
"\u001b[37minputs: \u001b[34ml\u001b[31m[no]\u001b[34md,\u001b[32mc\u001b[31m[n]\u001b[32mm\u001b[31m[o]\u001b[39m->\n",
"\u001b[37moutput: \u001b[34ml\u001b[34md\u001b[32m(cm)\n",
"\n",
"\u001b[37m(26) cost: \u001b[39m3.2e+01 \u001b[37mwidths: \u001b[39m4.0,3.0->3.0 \u001b[37mtype: \u001b[39mtensordot\n",
"\u001b[37minputs: \u001b[31m[l]\u001b[34md\u001b[34mc\u001b[31m[m],\u001b[32ma\u001b[31m[lm]\u001b[39m->\n",
"\u001b[37moutput: \u001b[34md\u001b[34mc\u001b[32m(a)\n",
"\n",
"\u001b[37m(27) cost: \u001b[39m1.6e+01 \u001b[37mwidths: \u001b[39m3.0,3.0->2.0 \u001b[37mtype: \u001b[39mtensordot\n",
"\u001b[37minputs: \u001b[31m[dc]\u001b[34ma,\u001b[31m[c]\u001b[32mb\u001b[31m[d]\u001b[39m->\n",
"\u001b[37moutput: \u001b[34ma\u001b[32m(b)\n",
"\n",
"\u001b[37m(28) cost: \u001b[39m4.0e+00 \u001b[37mwidths: \u001b[39m2.0,2.0->0.0 \u001b[37mtype: \u001b[39mtensordot\n",
"\u001b[37minputs: \u001b[31m[ab],\u001b[31m[ab]\u001b[39m->\n",
"\u001b[37moutput: \n",
"\n",
"\u001b[39m\n"
]
}
],
"source": [
"tree.print_contractions()"
]
},
{
"cell_type": "markdown",
"id": "f2726b0d",
"metadata": {},
"source": [
"The indices are colored according to:\n",
"\n",
"1. blue - appears on left input and is kept\n",
"2. green - appears on right input and is kept\n",
"3. red - appears on both inputs and is contracted away\n",
"4. purple - appears on both inputs and is kept (a 'batch' or hyper index)"
]
},
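{
"cell_type": "markdown",
"id": "c41e8a73",
"metadata": {},
"source": [
"These colors follow from set membership alone: an index on both inputs is red or purple depending on whether it is still needed afterwards, and otherwise it is blue or green according to which input carries it. A minimal sketch of that rule, where `keep` is assumed to hold every index still needed after the step (the output plus indices on tensors not yet contracted). The helper `color_indices` is hypothetical and not the code `cotengra` itself uses:\n",
"\n",
"```python\n",
"def color_indices(left, right, keep):\n",
"    colors = {}\n",
"    # dict.fromkeys keeps first-seen order while removing duplicates\n",
"    for ix in dict.fromkeys(left + right):\n",
"        if ix in left and ix in right:\n",
"            # shared indices are either summed away or kept as batch indices\n",
"            colors[ix] = 'purple' if ix in keep else 'red'\n",
"        elif ix in left:\n",
"            # assumes single-input indices are kept, as in the legend\n",
"            colors[ix] = 'blue'\n",
"        else:\n",
"            colors[ix] = 'green'\n",
"    return colors\n",
"\n",
"\n",
"# step (0) of the listing above: ihj,kj -> ihk\n",
"print(color_indices('ihj', 'kj', keep='ihk'))\n",
"```"
]
},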
{
"cell_type": "markdown",
"id": "a600a924",
"metadata": {},
"source": [
"### Tent\n",
"\n",
"The most general-purpose visualization for the structure of a\n",
"[`ContractionTree`](cotengra.core.ContractionTree) is\n",
"the [`ContractionTree.plot_tent`](cotengra.plot.plot_tent) method.\n",
"This plots the input network (in grey) at the bottom, and the contraction tree\n",
"intermediates laid out above. The width and color of the tree edges denote the\n",
"intermediate tensor widths, and the size and color of the tree nodes denote the\n",
"FLOPs required to contract each intermediate tensor:"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "cf698cfd",
"metadata": {},
"outputs": [
{
"data": {
"image/svg+xml": [
""
],
"text/plain": [
""
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"tree.plot_tent();"
]
},
{
"cell_type": "markdown",
"id": "171ab2f0",
"metadata": {},
"source": [
"If you supply `order=True` then the intermediate nodes will be placed in the exact vertical order in which they would be performed:"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "c957d6bb",
"metadata": {},
"outputs": [
{
"data": {
"image/svg+xml": [
""
],
"text/plain": [
""
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"tree.plot_tent(order=True);"
]
},
{
"cell_type": "markdown",
"id": "1421bc44",
"metadata": {},
"source": [
"```{note}\n",
"If you have sliced indices, these will appear as dashed lines in the input graph.\n",
"```"
]
},
{
"cell_type": "markdown",
"id": "ccbc977d",
"metadata": {},
"source": [
"### Circuit\n",
"\n",
"If you want to plot only the tree with an emphasis on the order of operations\n",
"then you can use the [`ContractionTree.plot_circuit`](cotengra.plot.plot_tree_circuit)\n",
"method:"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "363fd734",
"metadata": {},
"outputs": [
{
"data": {
"image/svg+xml": [
""
],
"text/plain": [
""
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"tree.plot_circuit();"
]
},
{
"cell_type": "markdown",
"id": "0e457bf8",
"metadata": {},
"source": [
"### Ring\n",
"\n",
"Another option is the [`ContractionTree.plot_ring`](cotengra.plot.plot_ring)\n",
"method which lays out the input network on a ring, with the contraction\n",
"tree intermediates laid out towards the center. The more arcs that cross between\n",
"branches, the more expensive that contraction is. This can be useful for\n",
"inspecting how many 'spines' a contraction has or how *balanced* it is:"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "b62d5f20",
"metadata": {},
"outputs": [
{
"data": {
"image/svg+xml": [
""
],
"text/plain": [
""
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"tree.plot_ring();"
]
},
{
"cell_type": "markdown",
"id": "753f1a4d",
"metadata": {},
"source": [
"### Rubberband\n",
"\n",
"For small, close-to-planar graphs, an alternative visualization is the\n",
"[`ContractionTree.plot_rubberband`](cotengra.plot.plot_rubberband) method,\n",
"which uses [`quimb`](https://github.com/jcmgray/quimb). Here, nodes of the input graph are hierarchically grouped into bands\n",
"according to the contraction tree. The order of contraction is represented by\n",
"the colors:"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "352a5af3",
"metadata": {},
"outputs": [
{
"data": {
"image/svg+xml": [
""
],
"text/plain": [
""
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"tree.plot_rubberband();"
]
},
{
"cell_type": "markdown",
"id": "ed653251",
"metadata": {},
"source": [
"All of the above methods can be extensively customized, including by\n",
"supplying custom colormaps. They also return `(fig, ax)` for further\n",
"customization or embedding in other plots."
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "a8610954",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"log10[FLOPS]=11.54 log10[COMBO]=11.99 log2[SIZE]=31 log2[PEAK]=32: 100%|█| 128/128 [00:\n"
]
}
],
"source": [
"inputs, output, shapes, size_dict = ctg.utils.lattice_equation([5, 5, 5])\n",
"opt = ctg.HyperOptimizer(progbar=True, reconf_opts={}, minimize=\"combo-256\")\n",
"tree = opt.search(inputs, output, size_dict)"
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "0632ef70",
"metadata": {},
"outputs": [
{
"data": {
"image/svg+xml": [
"