import time
from distutils.version import LooseVersion
from collections import OrderedDict
from .proto.attr_value_pb2 import AttrValue
from .proto.graph_pb2 import GraphDef
from .proto.node_def_pb2 import NodeDef
from .proto.step_stats_pb2 import RunMetadata, StepStats, DeviceStepStats, NodeExecStats, AllocatorMemoryUsed
from .proto.tensor_shape_pb2 import TensorShapeProto
from .proto.versions_pb2 import VersionDef
from .proto_graph import Node_proto
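# Methods of torch._C.Node (operator nodes) and torch._C.Value (graph
# inputs/outputs) whose results are copied onto the Python wrappers below.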
methods_OP = ['attributeNames', 'hasMultipleOutputs', 'hasUses', 'inputs',
'kind', 'outputs', 'outputsSize', 'scopeName']
methods_IO = ['node', 'offset', 'uniqueName']  # other Value methods: 'unique' <int>, 'type' <torch._C.Type>
class Node_base(object):
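    """Plain container for one graph node: just the fields needed to emit
    a NodeDef proto (name, inputs, tensor size, op kind, attributes)."""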
def __init__(self, uniqueName=None, inputs=None, scope=None, tensorSize=None, op_type='UnSpecified', attributes=''):
self.uniqueName = uniqueName
self.inputs = inputs
self.tensorSize = tensorSize
self.kind = op_type
self.attributes = attributes
if scope is not None:
self.scope = scope
    def __repr__(self):
        repr_strs = [str(type(self))]  # avoid shadowing the repr() builtin
        for m in dir(self):
            if '__' not in m:
                repr_strs.append(m + ': ' + str(getattr(self, m)) + str(type(getattr(self, m))))
        return '\n'.join(repr_strs) + '\n\n'
class Node_py(Node_base):
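    """Wraps a C++ torch._C node, eagerly copying the results of the
    methods listed in valid_methods into plain Python attributes."""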
    def __init__(self, Node_cpp, valid_methods):
        super(Node_py, self).__init__()
        self.valid_methods = valid_methods[:]
        self.inputs = []
        for m in self.valid_methods:
if m == 'inputs' or m == 'outputs':
list_of_node = list(getattr(Node_cpp, m)())
io_uniqueName_list = []
io_tensorSize_list = []
for n in list_of_node:
io_uniqueName_list.append(n.uniqueName())
                    if n.type().kind() in ['DynamicType', 'ListType']:  # calling sizes() on these kinds can segfault
io_tensorSize_list.append(None)
else:
io_tensorSize_list.append(n.type().sizes())
setattr(self, m, io_uniqueName_list)
setattr(self, m + 'TensorSize', io_tensorSize_list)
else:
setattr(self, m, getattr(Node_cpp, m)())
class Node_py_IO(Node_py):
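    """A graph input, output, or parameter; its kind is always 'Parameter'."""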
def __init__(self, Node_cpp, input_or_output=None):
super(Node_py_IO, self).__init__(Node_cpp, methods_IO)
self.tensorSize = Node_cpp.type().sizes()
self.kind = 'Parameter'
if input_or_output:
self.input_or_output = input_or_output
class Node_py_OP(Node_py):
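    """An operator node; kind and attributes are taken from the C++ node."""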
def __init__(self, Node_cpp):
super(Node_py_OP, self).__init__(Node_cpp, methods_OP)
self.attributes = str({k: Node_cpp[k] for k in Node_cpp.attributeNames()}).replace("'", ' ')
self.kind = Node_cpp.kind()
class Graph_py(object):
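    """Accumulates parsed nodes and rewrites their unique names into
    scope-qualified names before conversion to proto form."""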
def __init__(self):
self.nodes_OP = []
self.nodes_IO = OrderedDict()
self.uniqueNameToScopedName = {}
def append(self, x):
if type(x) == Node_py_IO:
self.nodes_IO[x.uniqueName] = x
if type(x) == Node_py_OP:
self.nodes_OP.append(x)
for node_output, outputSize in zip(x.outputs, x.outputsTensorSize):
self.nodes_IO[node_output] = Node_base(node_output,
x.inputs,
x.scopeName,
outputSize,
op_type=x.kind,
attributes=x.attributes)
def printall(self):
print('all nodes')
for node in self.nodes_OP:
print(node)
for key in self.nodes_IO:
print(self.nodes_IO[key])
def populate_namespace_from_OP_to_IO(self):
for node in self.nodes_OP:
for input_node_id in node.inputs:
self.uniqueNameToScopedName[input_node_id] = node.scopeName + '/' + input_node_id
for key, node in self.nodes_IO.items():
if type(node) == Node_base:
self.uniqueNameToScopedName[key] = node.scope + '/' + node.uniqueName
if hasattr(node, 'input_or_output'):
self.uniqueNameToScopedName[key] = node.input_or_output + '/' + node.uniqueName
        # Rewrite recorded names with their scope-qualified equivalents.
for key, node in self.nodes_IO.items():
self.nodes_IO[key].inputs = [self.uniqueNameToScopedName[node_input_id] for node_input_id in node.inputs]
if node.uniqueName in self.uniqueNameToScopedName:
self.nodes_IO[key].uniqueName = self.uniqueNameToScopedName[node.uniqueName]
def to_proto(self):
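        """Convert the collected IO nodes into parallel lists of
        NodeDef protos and NodeExecStats memory records."""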
import numpy as np
nodes = []
node_stats = []
for v in self.nodes_IO.values():
nodes.append(Node_proto(v.uniqueName,
input=v.inputs,
outputsize=v.tensorSize,
op=v.kind,
attributes=v.attributes))
            if v.tensorSize and len(v.tensorSize) > 0:  # assume float32 data; only parameters are counted
                node_stats.append(NodeExecStats(node_name=v.uniqueName,
                                                all_start_micros=int(time.time() * 1e6),
                                                all_end_rel_micros=42,
                                                memory=[AllocatorMemoryUsed(allocator_name="cpu",
                                                                            total_bytes=np.prod(v.tensorSize) * 4)]))
return nodes, node_stats
# Node methods that take one argument (not harvested above): 'hasAttribute', 'hasAttributes'
def parse(graph, args=None, omit_useless_nodes=True):
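    """Walk an optimized torch._C.Graph and return (nodes, node_stats).

    args are the example inputs used for tracing; graph inputs beyond
    len(args) are treated as model parameters. When omit_useless_nodes
    is set, graph inputs with no users are skipped.
    """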
    n_inputs = len(args)  # graph inputs beyond this index are model parameters
scope = {}
nodes_py = Graph_py()
for i, node in enumerate(graph.inputs()):
if omit_useless_nodes:
            if len(node.uses()) == 0:  # number of users of this node (its fan-out)
continue
if i < n_inputs:
nodes_py.append(Node_py_IO(node, 'input'))
else:
nodes_py.append(Node_py_IO(node)) # parameter
for node in graph.nodes():
nodes_py.append(Node_py_OP(node))
    for node in graph.outputs():  # output nodes must be processed last
        nodes_py.append(Node_py_IO(node, 'output'))
nodes_py.populate_namespace_from_OP_to_IO()
return nodes_py.to_proto()
def graph(model, args, verbose=False, omit_useless_nodes=True):
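    """Trace model with args, optimize the trace for ONNX-style export,
    and return a (GraphDef, RunMetadata) pair for TensorBoard.
    """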
import torch
from torch.onnx.utils import OperatorExportTypes
from torch.onnx import utils
def _optimize_trace(trace, operator_export_type):
trace.set_graph(_optimize_graph(trace.graph(), operator_export_type))
def _optimize_graph(graph, operator_export_type):
# torch._C._jit_pass_remove_inplace_ops(graph)
        # we now record some ops like ones/zeros
# into a trace where we previously recorded constants
# use constant prop to maintain our current level of onnx support
# without implementing symbolics for all of them
torch._C._jit_pass_constant_propagation(graph)
torch.onnx.utils._split_tensor_list_constants(graph, graph)
# run dce to eliminate dead parts of the graph that might have been
# left behind by things like symbolic_override
torch._C._jit_pass_dce(graph)
torch._C._jit_pass_lint(graph)
# torch._C._jit_pass_canonicalize_ops(graph)
torch._C._jit_pass_lint(graph)
torch._C._jit_pass_peephole(graph, True)
torch._C._jit_pass_lint(graph)
# onnx only supports tensors, but 1 / 2 = 0.5 and tensor(1) / tensor(2) = 0
torch._C._jit_pass_prepare_division_for_onnx(graph)
# onnx only supports tensors, so we turn all out number types into tensors
torch._C._jit_pass_erase_number_types(graph)
# onnx does not support tuples, so try to remove them
torch._C._jit_pass_lower_all_tuples(graph)
torch._C._jit_pass_peephole(graph, True)
torch._C._jit_pass_lint(graph)
if operator_export_type != OperatorExportTypes.RAW:
graph = torch._C._jit_pass_onnx(graph, operator_export_type)
torch._C._jit_pass_lint(graph)
# torch._C._jit_pass_onnx_peephole(graph)
torch._C._jit_pass_lint(graph)
torch._C._jit_pass_dce(graph)
torch._C._jit_pass_lint(graph)
torch._C._jit_pass_fixup_onnx_loops(graph)
torch._C._jit_pass_lint(graph)
graph = torch._C._jit_pass_canonicalize(graph)
torch._C._jit_pass_lint(graph)
return graph
    assert LooseVersion(torch.__version__) >= LooseVersion("1.0.0"), 'This module requires PyTorch >= 1.0.0'
with torch.onnx.set_training(model, False):
try:
trace, _ = torch.jit.get_trace_graph(model, args)
except RuntimeError:
            print('Error occurred; no graph saved.')
            _ = model(*args)  # don't catch; just surface the error message
            print("Checking whether it is an ONNX problem...")
try:
import tempfile
torch.onnx.export(
model, args, tempfile.TemporaryFile(), verbose=True)
except RuntimeError:
print("Your model fails onnx too, please report to onnx team")
return GraphDef(versions=VersionDef(producer=22))
_optimize_trace(trace, torch.onnx.utils.OperatorExportTypes.ONNX)
graph = trace.graph()
if verbose:
print(graph)
list_of_nodes, node_stats = parse(graph, args, omit_useless_nodes)
stepstats = RunMetadata(step_stats=StepStats(dev_stats=[DeviceStepStats(device="/device:CPU:0",
node_stats=node_stats)]))
return GraphDef(node=list_of_nodes, versions=VersionDef(producer=22)), stepstats
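

# Minimal usage sketch (assumes torchvision is available; any nn.Module
# traced with example tensor inputs works the same way):
if __name__ == '__main__':
    import torch
    import torchvision

    _model = torchvision.models.resnet18()
    _dummy_input = (torch.zeros(1, 3, 224, 224),)
    _graph_def, _run_metadata = graph(_model, _dummy_input)
    # _graph_def and _run_metadata are ready to be serialized into a
    # TensorBoard event file by the summary writer that owns this module.
    print(_graph_def.versions)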