Full Reference

class opt_einsum_fx.EfficientShapeProp(module: torch.fx.graph_module.GraphModule, garbage_collect_values: bool = True)

Like ShapeProp, traverses a graph Node-by-Node and records the shape and type of the result into each Node.

The exception is einsum, which we treat as a special case. We do not actually execute einsum on tensors, since the einsums will typically not be optimized yet (shape propagation runs before optimization), and an inefficient summation order can create enormous intermediate tensors, often causing needless out-of-memory errors.

So we override run_node only for einsum nodes; the shape of the result is straightforward to determine from the output indices alone.

(The call to opt_einsum that typically follows this pass also does not build actual tensors during its exploration.)
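For intuition, a minimal sketch (not the library's implementation) of how an einsum's output shape follows from the equation and the operand shapes alone, assuming an explicit "->" output and no "..." broadcasting:

import torch

def einsum_output_shape(equation, *operands):
    # Sketch only: assumes an explicit "->" output and no "..." broadcasting.
    inputs, output = equation.replace(" ", "").split("->")
    sizes = {}
    for spec, op in zip(inputs.split(","), operands):
        for letter, dim in zip(spec, op.shape):
            sizes[letter] = dim  # record the size of each index letter
    return tuple(sizes[letter] for letter in output)

a, b = torch.randn(3, 4), torch.randn(4, 5)
assert einsum_output_shape("ij,jk->ik", a, b) == (3, 5)

No contraction is performed; only the sizes of the output index letters are looked up.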

run_node(n: torch.fx.node.Node) Any

Run a specific node n and return the result. Calls into placeholder, get_attr, call_function, call_method, call_module, or output, depending on node.op.

Parameters

n (Node) – The Node to execute

Returns

The result of executing n

Return type

Any

Note

Backwards-compatibility for this API is guaranteed.
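A usage sketch, assuming EfficientShapeProp shares ShapeProp's propagate interface and its convention of storing shape metadata in each Node's meta["tensor_meta"] (the function f and input shapes are arbitrary):

import torch
import torch.fx
from opt_einsum_fx import EfficientShapeProp

def f(a, b):
    return torch.einsum("ij,jk->ik", a, b)

gm = torch.fx.symbolic_trace(f)
# Propagate example inputs through the graph, annotating each Node.
EfficientShapeProp(gm).propagate(torch.randn(4, 8), torch.randn(8, 2))
for node in gm.graph.nodes:
    meta = node.meta.get("tensor_meta")
    print(node.name, getattr(meta, "shape", None))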

opt_einsum_fx.fuse_einsums(graph: torch.fx.graph.Graph, in_place: bool = False) torch.fx.graph.Graph

Fuse einsums when possible.

When the output of one einsum is only used as an operand in another einsum, the two einsums can be fused into one.

Example

import torch
import torch.fx
from opt_einsum_fx import fuse_einsums

def fusable(x, y):
    z = torch.einsum("ij,jk->ik", x, y)
    return torch.einsum("ik,ij->i", z, x)

g = torch.fx.symbolic_trace(fusable)
print(fuse_einsums(g.graph).python_code(""))

gives:

import torch
def forward(self, x, y):
    einsum_2 = torch.functional.einsum('ib,bk,ij->i', x, y, x);  x = y = None
    return einsum_2

Parameters
  • graph – the graph to process.

  • in_place (bool, optional) – whether to process graph in place.

Returns

The graph with fused einsums.

opt_einsum_fx.fuse_scalars(graph: torch.fx.graph.Graph, in_place: bool = False) torch.fx.graph.Graph

Use the multilinearity of einsum to unify and remove constant scalars around einsums.

Parameters
  • graph – the graph to process.

  • in_place (bool, optional) – whether to process graph in place.

Returns

The graph with fused scalars.
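As an illustration of the kind of rewrite this enables (a hypothetical before/after, not the pass's literal output): because einsum is linear in each operand, constant scalar factors and divisors can be combined into a single coefficient.

import torch

def before(x, y):
    return 0.5 * torch.einsum("ij,jk->ik", 2.0 * x, y) / 4.0

# Arithmetically equivalent, with one combined coefficient 0.5 * 2.0 / 4.0 = 0.25:
def after(x, y):
    return 0.25 * torch.einsum("ij,jk->ik", x, y)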

opt_einsum_fx.jitable(obj: Union[torch.fx.graph_module.GraphModule, torch.fx.graph.Graph]) Union[torch.fx.graph_module.GraphModule, torch.fx.graph.Graph]

Convert some torch calls into their TorchScript signatures.

Operates in place. Currently handles tensordot and permute.

Parameters

obj – the fx.Graph or fx.GraphModule to process.

Returns

obj, modified in-place.
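A usage sketch of the typical workflow, assuming an optimized fx.GraphModule is made TorchScript-compatible and then scripted (the function f and input shapes are arbitrary):

import torch
import torch.fx
import opt_einsum_fx

def f(a, b, v):
    return torch.einsum("ij,jk,k->i", a, b, v)

mod = opt_einsum_fx.optimize_einsums_full(
    torch.fx.symbolic_trace(f),
    example_inputs=(torch.randn(3, 4), torch.randn(4, 5), torch.randn(5)),
)
# jitable rewrites calls such as tensordot into TorchScript-compatible
# signatures so the optimized module can be scripted.
scripted = torch.jit.script(opt_einsum_fx.jitable(mod))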

opt_einsum_fx.optimize_einsums(graph: torch.fx.graph.Graph, contract_kwargs: dict = {}) torch.fx.graph.Graph

Optimize einsums in a torch.fx.Graph using opt_einsum.

graph must have shape information, such as that populated by torch.fx.passes.shape_prop.ShapeProp. The shapes are used by opt_einsum, and the result is specific to the number of dimensions in the provided shapes. As the opt_einsum documentation notes:

…while it will work for a set of arrays with the same ranks as the original shapes but differing sizes, it might no longer be optimal.

See the opt_einsum documentation for more details.

Parameters
  • graph (fx.Graph) – the graph to optimize

  • contract_kwargs – extra keyword arguments for opt_einsum.contract_path.

Returns

An optimized fx.Graph.
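A sketch of this lower-level workflow (optimize_einsums_full wraps these steps), assuming the standard ShapeProp API; the function f and shapes are arbitrary:

import torch
import torch.fx
from torch.fx.passes.shape_prop import ShapeProp
from opt_einsum_fx import optimize_einsums

def f(a, b, v):
    return torch.einsum("ij,jk,k->i", a, b, v)

mod = torch.fx.symbolic_trace(f)
# Populate shape metadata, then optimize the graph and rebuild a module.
ShapeProp(mod).propagate(torch.randn(8, 16), torch.randn(16, 32), torch.randn(32))
optimized = torch.fx.GraphModule(mod, optimize_einsums(mod.graph))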

opt_einsum_fx.optimize_einsums_full(model: Union[torch.nn.modules.module.Module, Callable, torch.fx.graph.Graph], example_inputs: tuple, contract_kwargs: dict = {}, tracer_class: type = <class 'torch.fx._symbolic_trace.Tracer'>) Union[torch.fx.graph_module.GraphModule, torch.fx.graph.Graph]

Optimize einsums in model for example_inputs.

All of the restrictions of torch.fx symbolic tracing apply.

Applies, in order, four optimizations:

  1. Scalar accumulation — use the multilinearity of einsum to collect all constant coefficients and divisors of operands and outputs.

  2. Fusing einsums — gives greater flexibility to (3).

  3. Optimized contraction with opt_einsum.

  4. Moving constant scalar coefficients through operations they commute with, in order to place them on the smallest possible intermediate results.

Parameters
  • model (torch.nn.Module or callable or fx.Graph) – the model, function, or fx.Graph to optimize.

  • example_inputs (tuple) – arguments to model whose shapes will determine the einsum optimizations.

  • contract_kwargs (dict, optional) – extra keyword arguments for opt_einsum.contract_path.

  • tracer_class (type, optional) – the tracer class to use to turn model into an fx.Graph if it isn’t already an fx.GraphModule or fx.Graph.

Returns

An optimized fx.GraphModule, or if model is an fx.Graph, an optimized fx.Graph.
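A minimal end-to-end sketch (the function and shapes are arbitrary):

import torch
from opt_einsum_fx import optimize_einsums_full

def einmatvecmul(a, b, vec):
    return torch.einsum("ij,jk,k->i", a, b, vec)

opt_mod = optimize_einsums_full(
    einmatvecmul,
    example_inputs=(torch.randn(10, 20), torch.randn(20, 30), torch.randn(30)),
)
# The optimized module is called like the original function.
out = opt_mod(torch.randn(10, 20), torch.randn(20, 30), torch.randn(30))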