Extract contraction to matmuls only
%config InlineBackend.figure_formats = ['svg']
import autoray as ar
import cotengra as ctg
Create a random contraction and contraction tree:
inputs, output, shapes, size_dict = ctg.utils.rand_equation(
    n=6,
    reg=5,
    n_out=1,
    n_hyper_in=1,
    n_hyper_out=1,
    seed=42,
)
# square grid contraction:
# inputs, output, shapes, size_dict = ctg.utils.lattice_equation([3, 4])
tree = ctg.array_contract_tree(
    inputs, output, shapes=shapes, optimize="optimal"
)
tree.get_hypergraph().plot()
(<Figure size 500x500 with 1 Axes>, <Axes: >)
The high-level pairwise contractions can be displayed like so:
tree.print_contractions()
(0) cost: 3.2e+01 widths: 2.0,5.0->4.0 type: einsum
inputs: {b}[i],a{b}[i]no->
output: {b}(ano)
(1) cost: 6.4e+01 widths: 4.0,5.0->5.0 type: einsum
inputs: {ba}[n]o,{ab}d[n]h->
output: {ba}o(dh)
(2) cost: 7.7e+02 widths: 5.0,8.6->7.6 type: einsum
inputs: {ba}o[dh],{ab}c[d]efg[h]->
output: {ba}o(cefg)
(3) cost: 1.5e+03 widths: 7.6,7.0->8.6 type: einsum
inputs: {ba}[o]cef[g],{ab}k[g]l[o]m->
output: {ba}cef(klm)
(4) cost: 1.2e+03 widths: 8.6,9.2->2.6 type: tensordot+perm
inputs: [b]a[cefklm],[b]j[cekflm]->
output: (j)a
Tracing the computational graph
cotengra can also break each einsum call down further, into only (batch) matrix multiplies, reshapes and transposes. To extract these we need two things:
1. to trace the contraction with autoray lazy arrays,
2. to use implementation="cotengra", so that a backend's einsum implementation is not called directly.
variables = [ar.lazy.Variable(shape, backend="numpy") for shape in shapes]
lz = tree.contract(
    variables,
    # make cotengra use its own implementation of einsum/tensordot
    # which breaks things down to bmm / reshape / transpose
    implementation="cotengra",
)
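To make the decomposition concrete, here is a minimal NumPy sketch of a single pairwise step of the same form as contraction (0) listed earlier, `{b}[i],a{b}[i]no->{b}(ano)`, written once as an einsum and once as transpose, reshape and batched matmul. The random inputs and the explicit einsum string are assumptions for illustration, and the sequence shown is one way such a step can be lowered, not necessarily cotengra's exact internal code path.

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=(2, 2))           # indices: b, i
y = rng.normal(size=(2, 2, 2, 2, 2))  # indices: a, b, i, n, o

# reference result: b is a batch index, i is contracted
expected = np.einsum("bi,abino->bano", x, y)

# bring y into (batch, contracted, kept) order: b, i, a, n, o
y_t = np.transpose(y, (1, 2, 0, 3, 4))
# fuse the kept indices (a, n, o) into a single axis of size 8
y_m = np.reshape(y_t, (2, 2, 8))
# give x a dummy 'kept' axis so that it is a batch of 1 x 2 matrices
x_m = np.reshape(x, (2, 1, 2))

# batched matrix multiply: (2, 1, 2) @ (2, 2, 8) -> (2, 1, 8)
out = np.matmul(x_m, y_m)
# unfuse back to the individual output indices b, a, n, o
out = np.reshape(out, (2, 2, 2, 2))

assert np.allclose(out, expected)
```

Fusing indices with a reshape is only valid once the transpose has made the axes to be fused adjacent and in the right order, which is why each matmul in a trace of this kind tends to be surrounded by transpose and reshape calls.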
The remaining steps use only autoray functionality. Visualize the operations:
lz.plot_circuit()
(<Figure size 1049.51x1049.51 with 1 Axes>, <Axes: >)
Visualize as text / terminal:
lz.show()
0 transpose[3, 2]
1 ╰─matmul[2, 3]
2 ├─reshape[192, 3]
3 │ ╰─transpose[2, 2, 2, 3, 2, 2, 2, 3]
4 │ ╰─←[2, 3, 2, 2, 2, 3, 2, 2]
5 ╰─reshape[2, 192]
6 ╰─transpose[2, 2, 2, 2, 3, 2, 2, 2]
7 ╰─reshape[2, 2, 2, 2, 3, 2, 2, 2]
8 ╰─matmul[4, 12, 8]
9 ├─reshape[4, 4, 8]
10 │ ╰─transpose[2, 2, 2, 2, 2, 2, 2]
11 │ ╰─←[2, 2, 2, 2, 2, 2, 2]
12 ╰─reshape[4, 12, 4]
13 ╰─transpose[2, 2, 2, 2, 3, 2, 2]
14 ╰─reshape[2, 2, 2, 2, 2, 3, 2]
15 ╰─matmul[4, 2, 24]
16 ├─reshape[4, 4, 24]
17 │ ╰─transpose[2, 2, 2, 2, 2, 2, 3, 2]
18 │ ╰─←[2, 2, 2, 2, 2, 3, 2, 2]
19 ╰─reshape[4, 2, 4]
20 ╰─matmul[4, 2, 4]
21 ├─reshape[4, 2, 4]
22 │ ╰─transpose[2, 2, 2, 2, 2]
23 │ ╰─←[2, 2, 2, 2, 2]
24 ╰─reshape[4, 2, 2]
25 ╰─transpose[2, 2, 2, 2]
26 ╰─reshape[2, 2, 2, 2]
27 ╰─matmul[2, 1, 8]
28 ├─reshape[2, 1, 2]
29 │ ╰─←[2, 2]
30 ╰─reshape[2, 2, 8]
31 ╰─transpose[2, 2, 2, 2, 2]
32 ╰─←[2, 2, 2, 2, 2]
Some preparatory einsum calls might still appear, but only to remove initial diagonal indices, trivial indices and the like.
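As a rough illustration of why such a preparatory step cannot be a plain reshape or transpose, the NumPy sketch below removes a repeated (diagonal) index from a single array. The shapes and the index string are assumptions chosen for the example; they are not taken from the contraction above.

```python
import numpy as np

rng = np.random.default_rng(1)
x = rng.normal(size=(3, 3, 4))  # indices: i, i, j (i appears twice)

# a single-operand einsum extracts the diagonal over the repeated index
d = np.einsum("iij->ij", x)

# the same result via np.diagonal, which places the diagonal axis last
d_ref = np.diagonal(x, axis1=0, axis2=1).T

assert d.shape == (3, 4)
assert np.allclose(d, d_ref)
```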
Getting the linear form of the contraction
# names of the input variables in the contraction
[f"x{id(v)}" for v in variables]
['x139996605978592',
'x139996605418560',
'x139996605418080',
'x139996605426720',
'x139996605418752',
'x139996605418272']
# Python source code for the whole contraction
print(lz.get_source())
x139996605419040 = transpose139997039584176(x139996605418272, (1, 2, 0, 3, 4,))
x139996603351968 = reshape139997039578096(x139996605419040, (2, 2, 8,))
del x139996605419040
x139996605426048 = reshape139997039578096(x139996605418560, (2, 1, 2,))
x139996603351488 = matmul139997611038160(x139996605426048, x139996603351968)
del x139996603351968
del x139996605426048
x139996603351872 = reshape139997039578096(x139996603351488, (2, 2, 2, 2,))
del x139996603351488
x139996603351776 = transpose139997039584176(x139996603351872, (0, 1, 3, 2,))
del x139996603351872
x139996603351680 = reshape139997039578096(x139996603351776, (4, 2, 2,))
del x139996603351776
x139996603351584 = transpose139997039584176(x139996605426720, (1, 0, 3, 2, 4,))
x139996603352160 = reshape139997039578096(x139996603351584, (4, 2, 4,))
del x139996603351584
x139996603352256 = matmul139997611038160(x139996603351680, x139996603352160)
del x139996603352160
del x139996603351680
x139996603352544 = reshape139997039578096(x139996603352256, (4, 2, 4,))
del x139996603352256
x139996603352640 = transpose139997039584176(x139996605978592, (1, 0, 3, 7, 2, 4, 5, 6,))
x139996603352736 = reshape139997039578096(x139996603352640, (4, 4, 24,))
del x139996603352640
x139996603352832 = matmul139997611038160(x139996603352544, x139996603352736)
del x139996603352736
del x139996603352544
x139996603352928 = reshape139997039578096(x139996603352832, (2, 2, 2, 2, 2, 3, 2,))
del x139996603352832
x139996603353216 = transpose139997039584176(x139996603352928, (0, 1, 3, 4, 5, 2, 6,))
del x139996603352928
x139996603353312 = reshape139997039578096(x139996603353216, (4, 12, 4,))
del x139996603353216
x139996603353504 = transpose139997039584176(x139996605418752, (1, 0, 5, 3, 2, 4, 6,))
x139996603353600 = reshape139997039578096(x139996603353504, (4, 4, 8,))
del x139996603353504
x139996603353696 = matmul139997611038160(x139996603353312, x139996603353600)
del x139996603353600
del x139996603353312
x139996603353792 = reshape139997039578096(x139996603353696, (2, 2, 2, 2, 3, 2, 2, 2,))
del x139996603353696
x139996603354176 = transpose139997039584176(x139996603353792, (1, 0, 2, 3, 4, 5, 6, 7,))
del x139996603353792
x139996603354272 = reshape139997039578096(x139996603354176, (2, 192,))
del x139996603354176
x139996603354368 = transpose139997039584176(x139996605418080, (0, 2, 3, 5, 4, 6, 7, 1,))
x139996603354464 = reshape139997039578096(x139996603354368, (192, 3,))
del x139996603354368
x139996603354560 = matmul139997611038160(x139996603354272, x139996603354464)
del x139996603354464
del x139996603354272
x139996603354656 = transpose139997039584176(x139996603354560, (1, 0,))
del x139996603354560
Or access the nodes programmatically:
for node in ar.lazy.ascend(lz):
    print(f"fn: {node.fn}")
    print(f"args: {node.args}")
    print(f"kwargs: {node.kwargs}")
    print()
fn: None
args: None
kwargs: None
fn: <function transpose at 0x7f5399cea7a0>
args: (<LazyArray(fn=None, shape=(2, 2, 2, 2, 2), backend='numpy')>, (1, 2, 0, 3, 4))
kwargs: {}
fn: <function reshape at 0x7f5399cea160>
args: (<LazyArray(fn=transpose, shape=(2, 2, 2, 2, 2), backend='numpy')>, (2, 2, 8))
kwargs: {}
fn: None
args: None
kwargs: None
fn: <function reshape at 0x7f5399cea160>
args: (<LazyArray(fn=None, shape=(2, 2), backend='numpy')>, (2, 1, 2))
kwargs: {}
fn: <built-in function matmul>
args: (<LazyArray(fn=reshape, shape=(2, 1, 2), backend='numpy')>, <LazyArray(fn=reshape, shape=(2, 2, 8), backend='numpy')>)
kwargs: {}
fn: <function reshape at 0x7f5399cea160>
args: (<LazyArray(fn=matmul, shape=(2, 1, 8), backend='numpy')>, (2, 2, 2, 2))
kwargs: {}
fn: <function transpose at 0x7f5399cea7a0>
args: (<LazyArray(fn=reshape, shape=(2, 2, 2, 2), backend='numpy')>, (0, 1, 3, 2))
kwargs: {}
fn: <function reshape at 0x7f5399cea160>
args: (<LazyArray(fn=transpose, shape=(2, 2, 2, 2), backend='numpy')>, (4, 2, 2))
kwargs: {}
fn: None
args: None
kwargs: None
fn: <function transpose at 0x7f5399cea7a0>
args: (<LazyArray(fn=None, shape=(2, 2, 2, 2, 2), backend='numpy')>, (1, 0, 3, 2, 4))
kwargs: {}
fn: <function reshape at 0x7f5399cea160>
args: (<LazyArray(fn=transpose, shape=(2, 2, 2, 2, 2), backend='numpy')>, (4, 2, 4))
kwargs: {}
fn: <built-in function matmul>
args: (<LazyArray(fn=reshape, shape=(4, 2, 2), backend='numpy')>, <LazyArray(fn=reshape, shape=(4, 2, 4), backend='numpy')>)
kwargs: {}
fn: <function reshape at 0x7f5399cea160>
args: (<LazyArray(fn=matmul, shape=(4, 2, 4), backend='numpy')>, (4, 2, 4))
kwargs: {}
fn: None
args: None
kwargs: None
fn: <function transpose at 0x7f5399cea7a0>
args: (<LazyArray(fn=None, shape=(2, 2, 2, 2, 2, 3, 2, 2), backend='numpy')>, (1, 0, 3, 7, 2, 4, 5, 6))
kwargs: {}
fn: <function reshape at 0x7f5399cea160>
args: (<LazyArray(fn=transpose, shape=(2, 2, 2, 2, 2, 2, 3, 2), backend='numpy')>, (4, 4, 24))
kwargs: {}
fn: <built-in function matmul>
args: (<LazyArray(fn=reshape, shape=(4, 2, 4), backend='numpy')>, <LazyArray(fn=reshape, shape=(4, 4, 24), backend='numpy')>)
kwargs: {}
fn: <function reshape at 0x7f5399cea160>
args: (<LazyArray(fn=matmul, shape=(4, 2, 24), backend='numpy')>, (2, 2, 2, 2, 2, 3, 2))
kwargs: {}
fn: <function transpose at 0x7f5399cea7a0>
args: (<LazyArray(fn=reshape, shape=(2, 2, 2, 2, 2, 3, 2), backend='numpy')>, (0, 1, 3, 4, 5, 2, 6))
kwargs: {}
fn: <function reshape at 0x7f5399cea160>
args: (<LazyArray(fn=transpose, shape=(2, 2, 2, 2, 3, 2, 2), backend='numpy')>, (4, 12, 4))
kwargs: {}
fn: None
args: None
kwargs: None
fn: <function transpose at 0x7f5399cea7a0>
args: (<LazyArray(fn=None, shape=(2, 2, 2, 2, 2, 2, 2), backend='numpy')>, (1, 0, 5, 3, 2, 4, 6))
kwargs: {}
fn: <function reshape at 0x7f5399cea160>
args: (<LazyArray(fn=transpose, shape=(2, 2, 2, 2, 2, 2, 2), backend='numpy')>, (4, 4, 8))
kwargs: {}
fn: <built-in function matmul>
args: (<LazyArray(fn=reshape, shape=(4, 12, 4), backend='numpy')>, <LazyArray(fn=reshape, shape=(4, 4, 8), backend='numpy')>)
kwargs: {}
fn: <function reshape at 0x7f5399cea160>
args: (<LazyArray(fn=matmul, shape=(4, 12, 8), backend='numpy')>, (2, 2, 2, 2, 3, 2, 2, 2))
kwargs: {}
fn: <function transpose at 0x7f5399cea7a0>
args: (<LazyArray(fn=reshape, shape=(2, 2, 2, 2, 3, 2, 2, 2), backend='numpy')>, (1, 0, 2, 3, 4, 5, 6, 7))
kwargs: {}
fn: <function reshape at 0x7f5399cea160>
args: (<LazyArray(fn=transpose, shape=(2, 2, 2, 2, 3, 2, 2, 2), backend='numpy')>, (2, 192))
kwargs: {}
fn: None
args: None
kwargs: None
fn: <function transpose at 0x7f5399cea7a0>
args: (<LazyArray(fn=None, shape=(2, 3, 2, 2, 2, 3, 2, 2), backend='numpy')>, (0, 2, 3, 5, 4, 6, 7, 1))
kwargs: {}
fn: <function reshape at 0x7f5399cea160>
args: (<LazyArray(fn=transpose, shape=(2, 2, 2, 3, 2, 2, 2, 3), backend='numpy')>, (192, 3))
kwargs: {}
fn: <built-in function matmul>
args: (<LazyArray(fn=reshape, shape=(2, 192), backend='numpy')>, <LazyArray(fn=reshape, shape=(192, 3), backend='numpy')>)
kwargs: {}
fn: <function transpose at 0x7f5399cea7a0>
args: (<LazyArray(fn=matmul, shape=(2, 3), backend='numpy')>, (1, 0))
kwargs: {}
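Iterating over the nodes also makes it easy to summarize the graph, for example to check that only matmul, reshape and transpose calls remain. The helper below is a hypothetical sketch (`count_ops` is not part of autoray or cotengra) that only assumes each node has an `fn` attribute, as in the loop above. It is demonstrated on stand-in objects so that it runs without either library.

```python
import operator
from collections import Counter
from types import SimpleNamespace


def count_ops(nodes):
    """Tally nodes by the name of the function that produced them."""
    counts = Counter()
    for node in nodes:
        fn = getattr(node, "fn", None)
        # leaf nodes (the input variables) have fn=None
        name = "input" if fn is None else getattr(fn, "__name__", repr(fn))
        counts[name] += 1
    return counts


# stand-in nodes for illustration only, not real LazyArray objects
stand_ins = [
    SimpleNamespace(fn=None),
    SimpleNamespace(fn=None),
    SimpleNamespace(fn=operator.matmul),
]
op_counts = count_ops(stand_ins)
print(op_counts)
```

With the traced contraction, the same helper could presumably be called as `count_ops(ar.lazy.ascend(lz))`, and any key other than `input`, `matmul`, `reshape` or `transpose` would point to a preparatory einsum.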