Visualization

Various aspects of the contraction process can be visualized using the following functions.

%config InlineBackend.figure_formats = ['svg']
import cotengra as ctg

# generate an equation representing a 2D lattice
inputs, output, shapes, size_dict = ctg.utils.lattice_equation([5, 6])
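An equation here consists of a list of index lists (one per tensor), a list of output indices, and a dictionary mapping each index to its dimension. To make that format concrete, the following is a hypothetical pure-Python stand-in for a lattice generator. It is not cotengra's lattice_equation: it does not produce shapes, and its index naming scheme and default bond dimension of 2 are assumptions. The results are stored under new names so the cotengra-generated variables above are left untouched.

```python
import itertools


def make_lattice_equation(dims, bond_dim=2):
    """Hypothetical helper (not cotengra's lattice_equation): one tensor per
    lattice site and one index per nearest-neighbour bond, open boundaries."""
    sites = list(itertools.product(*(range(d) for d in dims)))
    terms = {site: [] for site in sites}
    size_dict = {}
    for site in sites:
        for axis in range(len(dims)):
            # the neighbouring site one step further along this axis
            neighbour = site[:axis] + (site[axis] + 1,) + site[axis + 1:]
            if neighbour[axis] < dims[axis]:
                ix = f"{site}-{neighbour}"  # bond index named after its two sites
                terms[site].append(ix)
                terms[neighbour].append(ix)
                size_dict[ix] = bond_dim
    inputs = [terms[site] for site in sites]
    output = []  # no open indices: the network contracts to a scalar
    return inputs, output, size_dict


sketch_inputs, sketch_output, sketch_size_dict = make_lattice_equation([5, 6])
print(len(sketch_inputs), len(sketch_size_dict))  # 30 49
```

Every bond index appears on exactly two tensors, which is what makes this an ordinary graph rather than a hypergraph.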

Hypergraph visualization

The first thing we can visualize is the hypergraph corresponding to the geometry of the tensor network or equation, using HyperGraph.plot:

hg = ctg.get_hypergraph(inputs, output, size_dict)
hg.plot(draw_edge_labels=True)

By default, the nodes are colored according to a simple centrality measure.
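To illustrate what a centrality score captures, here is a plain degree-based version: the fraction of the other tensors that each tensor shares at least one index with. This is an assumed, illustrative measure, not the one cotengra actually computes; it only shows the idea that well-connected tensors score higher than peripheral ones.

```python
def degree_centrality(inputs):
    """Illustrative stand-in, not cotengra's measure: the fraction of the
    other tensors that each tensor shares at least one index with."""
    n = len(inputs)
    if n < 2:
        return [0.0] * n
    # map each index to the set of tensors it appears on
    where = {}
    for i, term in enumerate(inputs):
        for ix in term:
            where.setdefault(ix, set()).add(i)
    scores = []
    for i, term in enumerate(inputs):
        # all tensors reachable through any shared index, excluding itself
        neighbours = set().union(*(where[ix] for ix in term)) - {i}
        scores.append(len(neighbours) / (n - 1))
    return scores


print(degree_centrality(['ab', 'bc', 'cd']))  # [0.5, 1.0, 0.5]
```

In this three-tensor chain, the middle tensor touches both others and so gets the highest score.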

Note

Hyperedges are drawn as zero-size vertices connecting the tensors they appear on, like a COPY-tensor.

ctg.get_hypergraph(['ax', 'bx', 'cx']).plot(draw_edge_labels=True)
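A hyperedge is simply an index attached to more than two tensors. The hypothetical helper below (not part of cotengra) finds them by counting how many "ends" each index has, under the assumption that the output counts as one more end an index can attach to:

```python
def find_hyperedges(inputs, output=()):
    """Hypothetical helper: indices attached to more than two 'ends',
    counting each input tensor and the output as one end each."""
    counts = {}
    for term in list(inputs) + [output]:
        for ix in set(term):
            counts[ix] = counts.get(ix, 0) + 1
    return sorted(ix for ix, c in counts.items() if c > 2)


print(find_hyperedges(['ax', 'bx', 'cx']))  # ['x']
```

For the example above, only x is shared by all three tensors, so it is the one drawn as a zero-size vertex.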

Optimizer trials visualization

If we run a hyper-optimizer, we can visualize how the scores progress over the trials using the HyperOptimizer.plot_trials method:

opt = ctg.HyperOptimizer(methods=['greedy', 'kahypar', 'labels'], progbar=True)

# run it and generate a tree
tree = opt.search(inputs, output, size_dict)
log2[SIZE]: 6.00 log10[FLOPs]: 3.80: 100%|██████████| 128/128 [00:02<00:00, 50.78it/s]
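Each trial contributes one score to such a plot. A useful summary when reading it is the best score found so far, which is just a running minimum, assuming (as for cost-like objectives) that lower is better. A minimal sketch; the scores are made-up numbers, not taken from the run above:

```python
from itertools import accumulate


def best_so_far(scores):
    """Running minimum of trial scores, assuming lower is better."""
    return list(accumulate(scores, min))


# made-up example scores, not output from the optimizer above
print(best_so_far([5.0, 3.0, 4.0, 2.5, 2.7]))  # [5.0, 3.0, 3.0, 2.5, 2.5]
```

A curve that flattens early suggests that further trials are unlikely to help much.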

By default the y-axis is the objective score, but you can specify another quantity, e.g. 'flops':

opt.plot_trials('flops')

We can also plot the distribution of contraction costs against contraction widths using the HyperOptimizer.plot_scatter method:

opt.plot_scatter(x='size', y='flops')
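The size and flops quantities correspond to what the progress bar reports as log2[SIZE] and log10[FLOPs]. For a single pairwise contraction, both can be sketched by hand. The helper below is hypothetical and uses one common convention, counting FLOPs as the product of the sizes of every index involved; cotengra's exact accounting may differ, for example by a constant factor.

```python
import math


def pairwise_cost(ix1, ix2, keep, size_dict):
    """Output size and a simple FLOP estimate for contracting two tensors.

    keep: the indices that survive, i.e. those appearing elsewhere in the
    network or in the output. FLOPs are the product of the sizes of all
    involved indices (an assumed convention).
    """
    involved = set(ix1) | set(ix2)
    size = math.prod(size_dict[ix] for ix in involved & set(keep))
    flops = math.prod(size_dict[ix] for ix in involved)
    return size, flops


sizes = {'a': 2, 'b': 2, 'c': 2, 'd': 2}
size, flops = pairwise_cost('abc', 'cd', keep='abd', size_dict=sizes)
print(size, flops)  # 8 16
print(math.log2(size), math.log10(flops))  # the scales the progress bar uses
```

For a whole tree, the size axis would reflect the largest such intermediate, while the FLOPs would add up over all contractions.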

Tree visualizations

The following visualization functions are available for inspecting a single, complete ContractionTree once it has been generated. They mostly wrap plot_tree, whose signature and docstring show most of the extra options.

Tent

The most general-purpose visualization for a ContractionTree is the ContractionTree.plot_tent method. This plots the input network (in grey) at the bottom, with the contraction tree intermediates laid out above it. The width and color of the tree edges denote the intermediate tensor widths, while the size and color of the tree nodes denote the FLOPs required to contract each intermediate tensor:

tree.plot_tent()
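The per-intermediate quantities drawn in the tent plot can be computed by hand. The sketch below assumes a hypothetical nested-tuple encoding of the tree (an integer is a position in inputs, a pair is a pairwise contraction) instead of cotengra's ContractionTree. Each intermediate keeps only the indices it shares with the rest of the network or with the output; its width is log2 of its size, and its FLOPs are estimated, as an assumed convention, as the product of the sizes of all indices involved.

```python
import math


def intermediate_costs(tree, inputs, output, size_dict):
    """Return (leaves, indices, width, flops) for every intermediate.

    `tree` is a hypothetical nested-tuple encoding: an int is a position in
    `inputs`, a 2-tuple (left, right) is a pairwise contraction.
    """
    all_leaves = frozenset(range(len(inputs)))
    results = []

    def indices_of(leaves):
        # an intermediate keeps only indices shared with the rest of the
        # network or with the output
        inside = {ix for i in leaves for ix in inputs[i]}
        outside = {ix for i in all_leaves - leaves for ix in inputs[i]}
        return inside & (outside | set(output))

    def visit(node):
        if isinstance(node, int):
            return frozenset([node])
        left, right = visit(node[0]), visit(node[1])
        leaves = left | right
        kept = indices_of(leaves)
        involved = indices_of(left) | indices_of(right)
        width = math.log2(math.prod(size_dict[ix] for ix in kept))
        flops = math.prod(size_dict[ix] for ix in involved)
        results.append((leaves, kept, width, flops))
        return leaves

    visit(tree)
    return results


for leaves, kept, width, flops in intermediate_costs(
    ((0, 1), 2), ['ab', 'bc', 'cd'], 'ad', {ix: 2 for ix in 'abcd'}
):
    print(sorted(leaves), sorted(kept), width, flops)
# [0, 1] ['a', 'c'] 2.0 8
# [0, 1, 2] ['a', 'd'] 2.0 8
```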

If you supply order=True, the intermediate nodes are arranged vertically in the exact order in which they would be performed:

tree.plot_tent(order=True)
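Any valid contraction order must form each intermediate only after both of its children. One way to produce such a sequence is a depth-first post-order traversal, sketched here on a hypothetical nested-tuple tree encoding (integers are input positions, pairs are contractions) rather than cotengra's internal representation. cotengra may well use a different traversal; any order that respects these dependencies is valid.

```python
def contraction_order(tree):
    """Return intermediates (as sorted tuples of input positions) in
    depth-first post-order, so each appears after both of its children."""
    order = []

    def visit(node):
        if isinstance(node, int):
            return (node,)
        left, right = node
        leaves = tuple(sorted(visit(left) + visit(right)))
        order.append(leaves)
        return leaves

    visit(tree)
    return order


print(contraction_order(((0, 1), (2, (3, 4)))))
# [(0, 1), (3, 4), (2, 3, 4), (0, 1, 2, 3, 4)]
```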

Note

If you have sliced indices, these will appear as dashed lines in the input graph.
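Slicing an index means fixing it to each of its values in turn, contracting each smaller network independently, and summing the results, which generally trades some extra (but parallelizable) work for a smaller peak memory footprint. A minimal NumPy check of that identity for a single sliced index; the arrays and index names are arbitrary:

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=(3, 4))  # indices a, b
y = rng.normal(size=(4, 5))  # indices b, c

full = np.einsum('ab,bc->ac', x, y)

# slice index b: fix it to each value in turn, contract the smaller
# networks independently, and sum the partial results
sliced = sum(np.einsum('a,c->ac', x[:, k], y[k, :]) for k in range(x.shape[1]))

print(np.allclose(full, sliced))  # True
```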

Ring

Another option is the ContractionTree.plot_ring method, which lays out the input network on a ring, with the contraction tree intermediates laid out towards the center. The more arcs that cross between branches, the more expensive that contraction. This can be useful for inspecting how many ‘spines’ a contraction has or how balanced it is:

tree.plot_ring()
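One rough way to quantify balance is the depth of the contraction tree. A tree with a single long spine, where inputs are absorbed one at a time into one growing intermediate, has depth n - 1 for n inputs, while a perfectly balanced tree has depth of about log2(n). A sketch using a hypothetical nested-tuple tree encoding (integers are input positions, pairs are contractions); none of these helpers is a cotengra API:

```python
def depth(tree):
    """Number of contractions on the longest root-to-leaf path."""
    if isinstance(tree, int):
        return 0
    left, right = tree
    return 1 + max(depth(left), depth(right))


def spine_tree(n):
    """Absorb inputs one at a time into a single growing intermediate."""
    tree = 0
    for i in range(1, n):
        tree = (tree, i)
    return tree


def balanced_tree(leaves):
    """Recursively split the inputs in half."""
    if len(leaves) == 1:
        return leaves[0]
    mid = len(leaves) // 2
    return (balanced_tree(leaves[:mid]), balanced_tree(leaves[mid:]))


print(depth(spine_tree(8)), depth(balanced_tree(list(range(8)))))  # 7 3
```

Depth alone does not determine cost, which also depends on the index sizes, but it does reflect the shape that a ring plot makes visible.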

Rubberband

For small and close-to-planar graphs, an alternative visualization is the ContractionTree.plot_rubberband method, which uses the rubberband drawing method from the hypernetx package. Here, nodes of the input graph are hierarchically grouped into bands according to the contraction tree, and the order of contraction is represented by the colors:

tree.plot_rubberband()
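Each band corresponds to one intermediate of the contraction tree, that is, the group of input tensors it was formed from, and its color reflects the step at which it is formed. Because the tree is hierarchical, the bands are always nested: any two are either disjoint or one contains the other. A small sketch that extracts the bands and checks this property, using a hypothetical nested-tuple tree encoding rather than cotengra's API:

```python
from itertools import combinations


def bands(tree):
    """(step, group of input positions) for every intermediate, with steps
    numbered in depth-first post-order."""
    groups = []

    def visit(node):
        if isinstance(node, int):
            return frozenset([node])
        left, right = node
        group = visit(left) | visit(right)
        groups.append((len(groups), group))
        return group

    visit(tree)
    return groups


def is_nested(groups):
    """True if any two groups are either disjoint or one contains the other."""
    return all(
        a.isdisjoint(b) or a <= b or b <= a
        for (_, a), (_, b) in combinations(groups, 2)
    )


groups = bands(((0, 1), (2, 3)))
print([(step, sorted(g)) for step, g in groups])
# [(0, [0, 1]), (1, [2, 3]), (2, [0, 1, 2, 3])]
print(is_nested(groups))  # True
```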

All of the above methods can be pretty extensively customized, including by supplying custom colormaps. Most also take a return_fig kwarg which can be used to return the matplotlib figure for more customization.

inputs, output, shapes, size_dict = ctg.utils.lattice_equation([5, 5, 5])
opt = ctg.HyperOptimizer(progbar=True, reconf_opts={}, minimize='combo-256')
tree = opt.search(inputs, output, size_dict)
log2[SIZE]: 30.00 log10[FLOPs]: 12.53: 100%|██████████| 128/128 [02:13<00:00,  1.04s/it]
tree.plot_tent(
    raw_edge_alpha=0.0,
    edge_colormap='Greens',
    node_colormap='Greens',
    edge_scale=2,
    node_scale=0.5,
    colorbars=False,
    tree_root_height=-1.0,
    figsize=(10, 10),
)