Viewing File: /home/ubuntu/.local/lib/python3.10/site-packages/tensorboardX/proto_graph.py

from .proto.graph_pb2 import GraphDef
from .proto.node_def_pb2 import NodeDef
from .proto.versions_pb2 import VersionDef
from .proto.attr_value_pb2 import AttrValue
from .proto.tensor_shape_pb2 import TensorShapeProto

from collections import defaultdict

# nodes.append(
#     NodeDef(name=node['name'], op=node['op'], input=node['inputs'],
#             attr={'lanpa': AttrValue(s=node['attr'].encode(encoding='utf_8')),
#                   '_output_shapes': AttrValue(list=AttrValue.ListValue(shape=[shapeproto]))}))


def AttrValue_proto(dtype,
                    shape,
                    s,
                    ):
    """Build the ``attr`` map for a ``NodeDef``.

    Args:
        dtype: Accepted for signature compatibility; not used here.
        shape: Iterable of output dimension sizes, or ``None`` to omit the
            ``'_output_shapes'`` entry.
        s: String stored UTF-8 encoded under the ``'attr'`` key, or ``None``
            to omit that entry.

    Returns:
        dict mapping attribute names to ``AttrValue`` messages. It is empty
        when both ``shape`` and ``s`` are ``None``.
    """
    result = {}
    if s is not None:
        result['attr'] = AttrValue(s=s.encode(encoding='utf_8'))
    if shape is None:
        return result
    # The output shapes are stored as a list holding exactly one shape.
    result['_output_shapes'] = AttrValue(
        list=AttrValue.ListValue(shape=[TensorShape_proto(shape)]))
    return result


def TensorShape_proto(outputsize):
    """Convert an iterable of dimension sizes into a ``TensorShapeProto``.

    Args:
        outputsize: Iterable of dimension sizes; each element becomes the
            ``size`` of one ``TensorShapeProto.Dim``, in iteration order.

    Returns:
        TensorShapeProto whose ``dim`` field holds one entry per element.
    """
    dims = [TensorShapeProto.Dim(size=size) for size in outputsize]
    return TensorShapeProto(dim=dims)


def Node_proto(name,
               op='UnSpecified',
               input=None,
               dtype=None,
               shape=None,  # type: tuple
               outputsize=None,
               attributes=''
               ):
    """Build a ``NodeDef`` protobuf describing one graph node.

    Args:
        name: Node name (str); UTF-8 encoded before being stored.
        op: Operator type string.
        input: Name of a single upstream node, or a list/tuple of names.
            ``None`` (the default) means the node has no inputs.
        dtype: Passed through to ``AttrValue_proto`` unchanged.
        shape: Unused; kept only for backward compatibility. The recorded
            output shape comes from ``outputsize`` instead.
        outputsize: Iterable of output dimension sizes, or ``None`` to omit
            the ``'_output_shapes'`` attribute.
        attributes: String stored under the ``'attr'`` attribute. Passing
            ``None`` omits it. The default ``''`` still produces an entry
            with empty bytes.

    Returns:
        NodeDef populated with ``name``, ``op``, ``input`` and ``attr``.
    """
    # The default is None rather than [] so no mutable default object is
    # shared between calls.
    if input is None:
        input = []
    elif isinstance(input, tuple):
        # Previously a tuple was wrapped as [tuple], which NodeDef rejects;
        # treat it as a sequence of names like a list.
        input = list(input)
    elif not isinstance(input, list):
        input = [input]
    return NodeDef(
        name=name.encode(encoding='utf_8'),
        op=op,
        input=input,
        attr=AttrValue_proto(dtype, outputsize, attributes)
    )
Back to Directory File Manager