Full Reference
- class opt_einsum_fx.EfficientShapeProp(module: torch.fx.graph_module.GraphModule, garbage_collect_values: bool = True)
Like ShapeProp, traverses a graph Node-by-Node and records the shape and type of the result into each Node.
Except that einsum is treated as a special case. We don't actually execute einsum on tensors, since the einsums will typically not be optimized yet (ShapeProp is called before optimization), and an inefficient summation order can create enormous intermediate tensors, often causing needless out-of-memory errors.
So we override run_node only for einsums. The shape of the result is straightforward to determine from the output indices alone.
(The call to opt_einsum that typically follows this also doesn't actually build the tensors during its exploration.)
- run_node(n: torch.fx.node.Node) → Any
Run a specific node n and return the result. Calls into placeholder, get_attr, call_function, call_method, call_module, or output depending on node.op.
- Parameters
n (Node) – The Node to execute
- Returns
The result of executing n
- Return type
Any
Note
Backwards-compatibility for this API is guaranteed.
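The shape shortcut used for einsum nodes can be sketched in plain Python: each output index takes the size it has in whichever operand it appears, so no tensor is ever materialized. A minimal sketch (the helper name infer_einsum_shape is hypothetical and not part of the package; it ignores broadcasting/ellipsis):

```python
def infer_einsum_shape(equation: str, *operand_shapes):
    """Infer an einsum result's shape from its output indices alone,
    without building any intermediate tensors."""
    lhs, out = equation.split("->")
    # Record the size each index label takes in the operands.
    sizes = {}
    for labels, shape in zip(lhs.split(","), operand_shapes):
        for label, dim in zip(labels, shape):
            sizes[label] = dim
    return tuple(sizes[label] for label in out)

print(infer_einsum_shape("ij,jk->ik", (2, 3), (3, 4)))  # (2, 4)
```

Because only the output labels are looked up, the cost is independent of how large the contracted intermediates would have been.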
- opt_einsum_fx.fuse_einsums(graph: torch.fx.graph.Graph, in_place: bool = False) → torch.fx.graph.Graph
Fuse einsums when possible.
When the output of one einsum is only used as an operand in another einsum, the two einsums can be fused into one.
Example
def fusable(x, y):
    z = torch.einsum("ij,jk->ik", x, y)
    return torch.einsum("ik,ij->i", z, x)

g = torch.fx.symbolic_trace(fusable)
print(fuse_einsums(g.graph).python_code(""))
gives:
import torch

def forward(self, x, y):
    einsum_2 = torch.functional.einsum('ib,bk,ij->i', x, y, x);  x = y = None
    return einsum_2
- Parameters
graph – the graph to process.
in_place (bool, optional) – whether to process graph in place.
- Returns
The graph with fused einsums.
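The index relabeling behind this fusion can be sketched on equation strings alone: the inner einsum's output labels are mapped onto the labels the outer equation uses for that operand, and the inner contraction indices get fresh labels. The function below is a hypothetical illustration, not part of the package:

```python
import string

def fuse_equations(inner_eq: str, outer_eq: str, pos: int) -> str:
    """Fuse inner_eq, whose result is operand number `pos` of outer_eq,
    into a single einsum equation (a sketch of what fuse_einsums does)."""
    inner_lhs, inner_out = inner_eq.split("->")
    outer_lhs, outer_out = outer_eq.split("->")
    outer_ops = outer_lhs.split(",")
    # The inner result's output indices line up with the labels the
    # outer equation uses for that operand.
    mapping = dict(zip(inner_out, outer_ops[pos]))
    used = set(outer_lhs) | set(outer_out)
    fresh = (c for c in string.ascii_lowercase
             if c not in used and c not in mapping.values())
    # Contracted (non-output) indices of the inner einsum get fresh labels.
    for label in sorted(set(inner_lhs) - set(inner_out) - {","}):
        mapping[label] = next(fresh)
    fused_ops = ["".join(mapping[c] for c in op) for op in inner_lhs.split(",")]
    outer_ops[pos:pos + 1] = fused_ops
    return ",".join(outer_ops) + "->" + outer_out

print(fuse_equations("ij,jk->ik", "ik,ij->i", 0))  # ia,ak,ij->i
```

For the example above this yields 'ia,ak,ij->i', which is the documented 'ib,bk,ij->i' up to the choice of the fresh label.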
- opt_einsum_fx.fuse_scalars(graph: torch.fx.graph.Graph, in_place: bool = False) → torch.fx.graph.Graph
Use the multilinearity of einsum to unify and remove constant scalars around einsums.
- Parameters
graph – the graph to process.
in_place (bool, optional) – whether to process graph in place.
- Returns
The graph with fused scalars.
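The multilinearity this pass relies on: scaling any operand of an einsum by a constant scales the result by the same constant, so scattered scalar factors can be pulled out and collected into a single coefficient. A tiny pure-Python illustration with a dot product (the helpers are illustrative, not part of the package):

```python
def scale(vec, c):
    return [c * x for x in vec]

def dot(a, b):
    # dot product, i.e. einsum("i,i->", a, b)
    return sum(x * y for x, y in zip(a, b))

a, b = [1.0, 2.0], [3.0, 4.0]
# (2*a) . (3*b) == 6 * (a . b): the two scalars fuse into one coefficient
assert dot(scale(a, 2.0), scale(b, 3.0)) == 6.0 * dot(a, b)
```

fuse_scalars applies the same identity at the graph level, so a single multiplication replaces several.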
- opt_einsum_fx.jitable(obj: Union[torch.fx.graph_module.GraphModule, torch.fx.graph.Graph]) → Union[torch.fx.graph_module.GraphModule, torch.fx.graph.Graph]
Convert some torch calls into their TorchScript signatures.
In place. Currently deals with tensordot and permute.
- Parameters
obj – the fx.Graph or fx.GraphModule to process.
- Returns
obj, modified in-place.
- opt_einsum_fx.optimize_einsums(graph: torch.fx.graph.Graph, contract_kwargs: dict = {}) → torch.fx.graph.Graph
Optimize einsums in a torch.fx.Graph using opt_einsum.
graph must have shape information, such as that populated by torch.fx.passes.shape_prop.ShapeProp. The shapes are used for opt_einsum, and the result is specific to the number of dimensions in the provided shapes. From the opt_einsum documentation: "…while it will work for a set of arrays with the same ranks as the original shapes but differing sizes, it might no longer be optimal." See the opt_einsum documentation for more details.
- Parameters
graph (fx.Graph) – the graph to optimize.
contract_kwargs – extra keyword arguments for opt_einsum.contract_path.
- Returns
An optimized fx.Graph.
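To see why the contraction order found by opt_einsum matters, and why it depends on the concrete shapes: for a matrix chain A(100×2) · B(2×100) · C(100×5), the two parenthesizations differ by a factor of 35 in multiply count. A small arithmetic sketch (pure Python, not part of the package):

```python
def matmul_cost(shape_a, shape_b):
    """Scalar-multiply count for (m,k) @ (k,n), and the result shape."""
    (m, k), (k2, n) = shape_a, shape_b
    assert k == k2, "inner dimensions must match"
    return m * k * n, (m, n)

A, B, C = (100, 2), (2, 100), (100, 5)

# Left-to-right: (A @ B) @ C builds a 100x100 intermediate
c1, AB = matmul_cost(A, B)
c2, _ = matmul_cost(AB, C)

# Right-to-left: A @ (B @ C) builds only a 2x5 intermediate
c3, BC = matmul_cost(B, C)
c4, _ = matmul_cost(A, BC)

print(c1 + c2, c3 + c4)  # 70000 vs 2000 multiplies
```

The optimal order (and the savings) would change if the same equation were evaluated with different sizes, which is why the optimized graph is specific to the provided shapes.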
- opt_einsum_fx.optimize_einsums_full(model: Union[torch.nn.modules.module.Module, Callable, torch.fx.graph.Graph], example_inputs: tuple, contract_kwargs: dict = {}, tracer_class: type = <class 'torch.fx._symbolic_trace.Tracer'>) → Union[torch.fx.graph_module.GraphModule, torch.fx.graph.Graph]
Optimize einsums in model for example_inputs.
All of the restrictions of torch.fx symbolic tracing apply.
Applies, in order, four optimizations:
1. Scalar accumulation — use the multilinearity of einsum to collect all constant coefficients and divisors of operands and outputs
2. Fusing einsums — gives greater flexibility to (3)
3. Optimized contraction with opt_einsum
4. Moving constant scalar coefficients through operations they commute with in order to place them on the smallest possible intermediate results
- Parameters
model (torch.nn.Module or callable or fx.Graph) – the model, function, or fx.Graph to optimize.
example_inputs (tuple) – arguments to model whose shapes will determine the einsum optimizations.
tracer_class (type, optional) – the tracer class to use to turn model into an fx.Graph if it isn't already an fx.GraphModule or fx.Graph.
- Returns
An optimized fx.GraphModule, or, if model is an fx.Graph, an optimized fx.Graph.