Visualization#
Various aspects of the contraction process can be visualized using the following functions.
%config InlineBackend.figure_formats = ['svg']
import cotengra as ctg
# generate an equation representing a 2D lattice
inputs, output, shapes, size_dict = ctg.utils.lattice_equation([5, 6])
Hypergraph visualization#
The first thing we can do is visualize the hypergraph corresponding
to the geometry of the tensor network or equation, using
HyperGraph.plot:
hg = ctg.get_hypergraph(inputs, output, size_dict)
hg.plot(draw_edge_labels=True)
(<Figure size 500x500 with 1 Axes>, <Axes: >)
The default coloring of the nodes shows a simple centrality measure.
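If you want the underlying values rather than just the colors, the hypergraph can compute this centrality itself. A minimal sketch, assuming the simple_centrality method (which the plot uses internally) is present in your installed version:
# compute the same simple centrality measure used to color the nodes
# (simple_centrality is assumed here - check your cotengra version)
cents = hg.simple_centrality()
# peek at the most central nodes
sorted(cents.items(), key=lambda kv: -kv[1])[:3]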
Note
Hyper edges are shown as a zero size vertex connecting the tensors they appear on, like a COPY-tensor.
ctg.get_hypergraph(['ax', 'bx', 'cx']).plot(draw_edge_labels=True)
(<Figure size 500x500 with 1 Axes>, <Axes: >)
Optimizer trials visualization#
If we run a hyper optimizer, we can visualize how the scores progress with each trial using the
HyperOptimizer.plot_trials
method:
opt = ctg.HyperOptimizer(methods=['greedy', 'kahypar', 'labels'], progbar=True)
# run it and generate a tree
tree = opt.search(inputs, output, size_dict)
log2[SIZE]: 6.00 log10[FLOPs]: 3.46: 100%|██████████| 128/128 [00:02<00:00, 63.13it/s]
By default the y-axis is the objective score, but you can specify e.g. 'flops':
opt.plot_trials('flops')
(<Figure size 800x300 with 1 Axes>,
<Axes: xlabel='TRIAL', ylabel='log10[FLOPS]'>)
We can also plot the distribution of contraction costs against contraction widths using
the HyperOptimizer.plot_scatter
method:
opt.plot_scatter(x='size', y='flops')
(<Figure size 500x500 with 1 Axes>,
<Axes: xlabel='log2[SIZE]', ylabel='log10[FLOPS]'>)
Tree visualizations#
The following visualization functions are available for inspecting a single,
complete ContractionTree
once generated.
They mostly wrap plot_tree, where you can see most of the extra options.
Contractions#
tree.plot_contractions
gives you an overview of the memory and costs throughout the contraction:
tree.plot_contractions()
(<Figure size 800x300 with 2 Axes>,
<Axes: xlabel='contraction', ylabel='$\\log_2[SIZE]$'>)
Here, peak
is the memory required to store all the concurrent intermediates at that point,
write
is the size of the newly created intermediate tensor, the maximum of which is the
contraction width, and cost
is the number of scalar operations of each contraction.
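These quantities can also be queried directly from the tree if you want the numbers rather than a plot. A short sketch, assuming the accessor names below match your installed version:
# log2 size of the largest intermediate (the contraction width)
tree.contraction_width()
# total number of scalar operations
tree.contraction_cost()
# peak memory, and total size of all intermediates written
tree.peak_size(), tree.total_write()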
Tent#
The most general purpose visualization for the structure of a
ContractionTree
is
the ContractionTree.plot_tent
method.
This plots the input network (in grey) at the bottom, and the contraction tree
intermediates laid out above. The width and color of the tree edges denote the
intermediate tensor widths, and the size and color of the tree nodes denote the
FLOPs required to contract each intermediate tensor:
tree.plot_tent()
(<Figure size 500x500 with 3 Axes>, <Axes: >)
If you supply order=True
then the intermediate nodes will appear in the exact vertical order in which the contractions would be performed:
tree.plot_tent(order=True)
(<Figure size 500x500 with 3 Axes>, <Axes: >)
Note
If you have sliced indices, these will appear as dashed lines in the input graph.
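A quick sketch of producing such a sliced tree, where the target_size value is arbitrary and just chosen to force some slicing on this small example:
# re-run the optimizer, asking it to slice down to small intermediates
opt_sliced = ctg.HyperOptimizer(slicing_opts={'target_size': 2**4})
tree_sliced = opt_sliced.search(inputs, output, size_dict)
# the sliced indices appear as dashed lines in the input graph
tree_sliced.sliced_inds
tree_sliced.plot_tent()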
Ring#
Another option is the ContractionTree.plot_ring
method which lays out the input network on a ring, with the contraction
tree intermediates laid out towards the center. The more arcs that cross between
branches, the more expensive that contraction is. This can be useful for
inspecting how many ‘spines’ a contraction has or how balanced it is:
tree.plot_ring()
(<Figure size 500x500 with 3 Axes>, <Axes: >)
Rubberband#
For small and close-to-planar graphs, an alternative visualization is the
ContractionTree.plot_rubberband
method, which uses quimb.
Here, nodes of the input graph are hierarchically grouped into bands
according to the contraction tree. The order of contraction is represented by
the colors:
tree.plot_rubberband()
(<Figure size 500x500 with 1 Axes>, <Axes: >)
All of the above methods can be pretty extensively customized, including by
supplying custom colormaps. They also return (fig, ax)
for further
customization or embedding in other plots.
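For example, a brief sketch of tweaking and saving a returned figure with standard matplotlib calls:
# grab the figure and axes for further styling
fig, ax = tree.plot_contractions()
ax.set_title('5x6 lattice contraction profile')
fig.savefig('contractions.svg', bbox_inches='tight')
Finally, as a larger demonstration, we can generate a 3D lattice contraction and heavily customize its tent plot: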
inputs, output, shapes, size_dict = ctg.utils.lattice_equation([5, 5, 5])
opt = ctg.HyperOptimizer(progbar=True, reconf_opts={}, minimize='combo-256')
tree = opt.search(inputs, output, size_dict)
log2[SIZE]: 31.00 log10[FLOPs]: 11.54: 100%|██████████| 128/128 [00:21<00:00, 6.08it/s]
tree.plot_tent(
raw_edge_alpha=0.0,
edge_colormap='Greens',
node_colormap='Greens',
edge_scale=2,
node_scale=0.5,
colorbars=False,
tree_root_height=-1.0,
figsize=(10, 10),
)
(<Figure size 1000x1000 with 1 Axes>, <Axes: >)