Full Reference
- class opt_einsum_fx.EfficientShapeProp(module: torch.fx.graph_module.GraphModule, garbage_collect_values: bool = True)
Like ShapeProp, traverses a graph Node-by-Node and records the shape and type of the result into each Node.
Unlike ShapeProp, however, we treat ‘einsum’ as a special case: we don’t actually execute ‘einsum’ on tensors. The einsums will typically not be optimized yet (ShapeProp is called before optimization), and an inefficient summation order can create enormous intermediate tensors, often causing needless out-of-memory errors.
So we override ‘run_node’ only for einsums. The shape of the result is straightforward to determine from the output indices alone.
(The call to opt_einsum that typically follows also doesn’t actually build the tensors during its exploration.)
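A minimal sketch of this idea, written as a plain torch.fx.Interpreter subclass rather than EfficientShapeProp's actual implementation. The names ShapeOnlyEinsumInterpreter and einsum_output_shape are hypothetical; the sketch assumes every einsum has an explicit ‘->’ output and no ‘...’, and takes the result dtype and device from the first operand.

```python
import torch
import torch.fx


def einsum_output_shape(equation, operands):
    """Hypothetical helper: result shape of an einsum, read off the output indices."""
    # Assumes an explicit "->" and no "..." in the equation.
    inputs, output = equation.replace(" ", "").split("->")
    sizes = {}
    for term, operand in zip(inputs.split(","), operands):
        for index, size in zip(term, operand.shape):
            sizes[index] = size
    return tuple(sizes[index] for index in output)


class ShapeOnlyEinsumInterpreter(torch.fx.Interpreter):
    """Hypothetical sketch: run the graph, but never execute torch.einsum."""

    def run_node(self, n):
        if n.op == "call_function" and n.target is torch.einsum:
            args, _ = self.fetch_args_kwargs_from_env(n)
            equation, *operands = args
            # torch.einsum also accepts the operands as a single list.
            if len(operands) == 1 and isinstance(operands[0], (list, tuple)):
                operands = list(operands[0])
            shape = einsum_output_shape(equation, operands)
            # Only the result is allocated; the contraction itself, and any
            # large intermediates it would create, are never computed.
            return torch.zeros(shape, dtype=operands[0].dtype, device=operands[0].device)
        # Every other node runs exactly as in a normal Interpreter.
        return super().run_node(n)


def f(x, y):
    z = torch.einsum("ij,jk->ik", x, y)
    return z.sum(dim=1)


gm = torch.fx.symbolic_trace(f)
out = ShapeOnlyEinsumInterpreter(gm).run(torch.randn(3, 4), torch.randn(4, 5))
print(out.shape)  # torch.Size([3])
```

Downstream nodes still receive a real tensor of the correct shape, so the rest of the graph runs normally.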
- run_node(n: torch.fx.node.Node) → Any
Run a specific node n and return the result. Calls into placeholder, get_attr, call_function, call_method, call_module, or output depending on node.op.
- Parameters
n (Node) – The Node to execute
- Returns
The result of executing n
- Return type
Any
Note
Backwards-compatibility for this API is guaranteed.
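To make the dispatch on node.op concrete, here is a hypothetical Interpreter subclass (RecordingInterpreter is not part of torch or opt_einsum_fx) that records each node's op before delegating to the default run_node.

```python
import torch
import torch.fx


class RecordingInterpreter(torch.fx.Interpreter):
    """Hypothetical example: record which handler run_node dispatches to."""

    def __init__(self, module):
        super().__init__(module)
        self.ops = []

    def run_node(self, n):
        # n.op selects the handler: placeholder, get_attr, call_function,
        # call_method, call_module, or output.
        self.ops.append(n.op)
        return super().run_node(n)


def f(x):
    return torch.relu(x).sum()


interp = RecordingInterpreter(torch.fx.symbolic_trace(f))
result = interp.run(torch.randn(4))
print(interp.ops)  # ['placeholder', 'call_function', 'call_method', 'output']
```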
- opt_einsum_fx.fuse_einsums(graph: torch.fx.graph.Graph, in_place: bool = False) → torch.fx.graph.Graph
Fuse einsums when possible.
When the output of one einsum is only used as an operand in another einsum, the two einsums can be fused into one.
Example
```python
import torch
import torch.fx
from opt_einsum_fx import fuse_einsums

def fusable(x, y):
    z = torch.einsum("ij,jk->ik", x, y)
    return torch.einsum("ik,ij->i", z, x)

g = torch.fx.symbolic_trace(fusable)
print(fuse_einsums(g.graph).python_code(""))
```
gives:
```python
import torch
def forward(self, x, y):
    einsum_2 = torch.functional.einsum('ib,bk,ij->i', x, y, x); x = y = None
    return einsum_2
```
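The equation rewriting behind this kind of fusion can be sketched in plain Python. The helper fuse_equations below is hypothetical, not the library's code: it assumes explicit ‘->’ outputs, no ‘...’, and that the inner einsum's output is consumed exactly once, as operand number position of the outer einsum. The inner einsum's free indices take the letters of the operand they replace, and its summed indices get fresh letters, so it may pick different letters than fuse_einsums does (here ‘a’ instead of ‘b’).

```python
import string


def fuse_equations(inner: str, outer: str, position: int) -> str:
    """Hypothetical sketch: substitute einsum `inner` into operand `position` of `outer`.

    Assumes explicit '->' outputs and no '...' in either equation.
    """
    inner_inputs, inner_output = inner.split("->")
    outer_inputs, outer_output = outer.split("->")
    outer_terms = outer_inputs.split(",")
    target = outer_terms[position]
    if len(target) != len(inner_output):
        raise ValueError("inner output rank does not match the outer operand")

    # Free (output) indices of the inner einsum take the outer operand's letters.
    rename = dict(zip(inner_output, target))
    # Summed indices of the inner einsum get fresh letters unused by the outer one.
    used = set(outer.replace(",", "").replace("->", ""))
    fresh = (c for c in string.ascii_letters if c not in used)
    for index in inner_inputs.replace(",", ""):
        if index not in rename:
            rename[index] = next(fresh)

    new_terms = ["".join(rename[c] for c in term) for term in inner_inputs.split(",")]
    fused_terms = outer_terms[:position] + new_terms + outer_terms[position + 1:]
    return ",".join(fused_terms) + "->" + outer_output


print(fuse_equations("ij,jk->ik", "ik,ij->i", 0))  # ia,ak,ij->i
```

The operands are spliced in the same way: the outer einsum's operands (z, x) become (x, y, x), matching the fused call above.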
- Parameters
graph – the graph to process.
in_place (bool, optional) – whether to process graph in place.
- Returns
The graph with fused einsums.
- opt_einsum_fx.fuse_scalars(graph: torch.fx.graph.Graph, in_place: bool = False) → torch.fx.graph.Graph
Use the multilinearity of einsum to unify and remove constant scalars around einsums.
- Parameters
graph – the graph to process.
in_place (bool, optional) – whether to process graph in place.
- Returns
The graph with fused scalars.
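The identity this pass relies on can be checked numerically. The snippet below only illustrates multilinearity and is not fuse_scalars itself: constant scalars applied to the operands and to the output of an einsum collapse into a single coefficient.

```python
import torch

x = torch.randn(3, 4)
y = torch.randn(4, 5)

# Scalars scattered around the einsum: on each operand and on the output.
scattered = 0.5 * torch.einsum("ij,jk->ik", 2.0 * x, y / 4.0)

# By multilinearity they collapse into one coefficient: 0.5 * 2.0 / 4.0 = 0.25.
collected = 0.25 * torch.einsum("ij,jk->ik", x, y)

print(torch.allclose(scattered, collected))  # True
```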
- opt_einsum_fx.jitable(obj: Union[torch.fx.graph_module.GraphModule, torch.fx.graph.Graph]) → Union[torch.fx.graph_module.GraphModule, torch.fx.graph.Graph]
Convert some torch calls into their TorchScript signatures, in place. Currently deals with tensordot and permute.
- Parameters
obj – the fx.Graph or fx.GraphModule to process.
- Returns
obj, modified in place.
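As an illustration of this kind of rewrite, here is a sketch of a pass that turns the varargs form x.permute(1, 0) into the list form x.permute([1, 0]), matching the int[] dims argument of the permute schema. The function name permute_varargs_to_list is hypothetical, and reading ‘TorchScript signature’ this way for permute is an assumption; it is not jitable's implementation and does not touch tensordot.

```python
import torch
import torch.fx


def permute_varargs_to_list(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
    """Hypothetical sketch: rewrite x.permute(1, 0) as x.permute([1, 0]), in place."""
    for node in gm.graph.nodes:
        if node.op == "call_method" and node.target == "permute":
            tensor, *dims = node.args
            # Only rewrite the varargs form; a single list/tuple argument is left alone.
            if dims and all(isinstance(d, int) for d in dims):
                node.args = (tensor, list(dims))
    gm.graph.lint()
    gm.recompile()  # regenerate forward() from the modified graph
    return gm


def transpose(x):
    return x.permute(1, 0)


gm = permute_varargs_to_list(torch.fx.symbolic_trace(transpose))
print(gm.code)  # the permute call now receives a single list of dims
```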
- opt_einsum_fx.optimize_einsums(graph: torch.fx.graph.Graph, contract_kwargs: dict = {}) → torch.fx.graph.Graph
Optimize einsums in a torch.fx.Graph using opt_einsum.
graph must have shape information such as that populated by torch.fx.passes.shape_prop.ShapeProp. The shapes are used for opt_einsum, and the result is specific to the number of dimensions in the provided shapes. As opt_einsum puts it: “…while it will work for a set of arrays with the same ranks as the original shapes but differing sizes, it might no longer be optimal.”
See the opt_einsum documentation for more details.
- Parameters
graph (fx.Graph) – the graph to optimize
contract_kwargs – extra keyword arguments for opt_einsum.contract_path.
- Returns
An optimized fx.Graph.
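To see why the result depends on the provided shapes, the snippet below calls opt_einsum.contract_path directly (not through optimize_einsums) on two sets of arrays with the same ranks but different sizes. Only the shapes matter, so the arrays are left uninitialized with numpy.empty. For the chain ij,jk,kl->il, the cheapest first pairwise contraction differs between the two.

```python
import numpy as np
import opt_einsum

equation = "ij,jk,kl->il"

# Same ranks, different sizes.
arrays_a = [np.empty((2, 1000)), np.empty((1000, 2)), np.empty((2, 1000))]
arrays_b = [np.empty((1000, 2)), np.empty((2, 1000)), np.empty((1000, 2))]

path_a, _ = opt_einsum.contract_path(equation, *arrays_a)
path_b, _ = opt_einsum.contract_path(equation, *arrays_b)

print(path_a)  # first step contracts operands 0 and 1 (small 2x2 intermediate)
print(path_b)  # first step contracts operands 1 and 2 (small 2x2 intermediate)
```

A path chosen for arrays_a would still give correct results on arrays_b, but it would build a 1000×1000 intermediate that the other order avoids.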
- opt_einsum_fx.optimize_einsums_full(model: Union[torch.nn.modules.module.Module, Callable, torch.fx.graph.Graph], example_inputs: tuple, contract_kwargs: dict = {}, tracer_class: type = <class 'torch.fx._symbolic_trace.Tracer'>) → Union[torch.fx.graph_module.GraphModule, torch.fx.graph.Graph]
Optimize einsums in model for example_inputs.
All of the restrictions of torch.fx symbolic tracing apply.
Applies, in order, four optimizations:
1. Scalar accumulation: use the multilinearity of einsum to collect all constant coefficients and divisors of operands and outputs
2. Fusing einsums: gives greater flexibility to (3)
3. Optimized contraction with opt_einsum.
4. Moving constant scalar coefficients through operations they commute with in order to place them on the smallest possible intermediate results
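Step 4 rests on a simple cost argument: multiplying a tensor by a scalar takes one multiplication per element, so among the places where the coefficient may legally be applied, the smallest tensor is the cheapest. A toy sketch, where cheapest_place_for_scalar and the candidate shapes are hypothetical and every candidate is assumed to commute with the scalar:

```python
from __future__ import annotations

from math import prod


def cheapest_place_for_scalar(candidates: dict[str, tuple[int, ...]]) -> str:
    """Return the name of the candidate tensor with the fewest elements."""
    # Scaling a tensor costs one multiplication per element.
    return min(candidates, key=lambda name: prod(candidates[name]))


# Tensors along x @ (y @ z) with x: (1000, 2), y: (2, 1000), z: (1000, 3).
candidates = {
    "x": (1000, 2),
    "y": (2, 1000),
    "z": (1000, 3),
    "y @ z": (2, 3),
    "output": (1000, 3),
}
print(cheapest_place_for_scalar(candidates))  # y @ z
```

Here the scalar costs 6 multiplications on the intermediate y @ z instead of 3000 on the output.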
- Parameters
model (torch.nn.Module or callable or fx.Graph) – the model, function, or fx.Graph to optimize.
example_inputs (tuple) – arguments to model whose shapes will determine the einsum optimizations.
tracer_class (type, optional) – the tracer class to use to turn model into an fx.Graph if it isn’t already an fx.GraphModule or fx.Graph.
- Returns
An optimized fx.GraphModule, or, if model is an fx.Graph, an optimized fx.Graph.