【TVM源码学习笔记】2 模型导入from_onnx

2023-11-09

在前文模型加载时,使用relay.frontend.from_onnx(onnx_model, shape_dict)是将onnx模型转换为TVM可以识别的Graph IR。要理解这一流程,需要对onnx模型定义有基础的了解。

1.onnx模型文件简介

onnx模型的数据定义参见(onnx/onnx.proto at main · onnx/onnx · GitHub)onnx.proto文件。onnx模型的数据类型有如下几种:

        ModelProto:加载了一个onnx模型之后获得的就是一个ModelProto,它包含了一些版本信息,生产者信息和一个GraphProto。

        GraphProto:在GraphProto里面又包含了四个repeated数组,它们分别是node(NodeProto类型),input(ValueInfoProto类型),output(ValueInfoProto类型)和initializer(TensorProto类型);

        NodeProto:网络节点数据类型。GraphProto使用一个NodeProto数组记录网络的所有节点。每个节点都会有两个string类型的数组:input和output,表示当前节点的输入数据的源节点和输出数据的目的节点。通过input和output的指向关系来构建出一个深度学习模型的网络结构

        ValueInfoProto:GraphProto中有两个ValueInfoProto类型的数组:input和output,分别存放网络的输入输出;

        TensorProto:张量类型。GraphProto中定义了一个该类型的数组:initializer,用于存放网络的常量输入,也就是模型的权重参数。数组中的所有张量必须是有名字的。并且这些张量名字也出现在前述的input数组中。

        AttributeProto:网络节点(NodeProto类型)的属性类型,用来描述该节点的属性信息,比如Conv节点或者说卷积层的属性包含group,pad,strides等等。

        这里要注意一下, GraphProto中的input数组不仅包含我们一般理解中的图片输入的那个节点,还包含了模型中所有的权重。例如,Conv层里面的W权重实体是保存在initializer中的,那么相应的会有一个同名的元素在input中。其背后的逻辑应该是把权重也看成模型的输入,并通过initializer中的权重实体来对这个输入做初始化。 

initializer和input中都有网络权重,那么它们有什么区别呢?initializer是TensorProto类型数组,记录了数据张量的名字、类型、shape等信息。

message TensorProto {
  enum DataType {
    UNDEFINED = 0;
    // Basic types.
    FLOAT = 1;   // float
    UINT8 = 2;   // uint8_t
    INT8 = 3;    // int8_t
    UINT16 = 4;  // uint16_t
    INT16 = 5;   // int16_t
    INT32 = 6;   // int32_t
    INT64 = 7;   // int64_t
    STRING = 8;  // string
    BOOL = 9;    // bool
    FLOAT16 = 10;

    DOUBLE = 11;
    UINT32 = 12;
    UINT64 = 13;
    COMPLEX64 = 14;     // complex with float32 real and imaginary components
    COMPLEX128 = 15;    // complex with float64 real and imaginary components
    BFLOAT16 = 16;
  }

  repeated int64 dims = 1;
  optional int32 data_type = 2;
  
  message Segment {
    optional int64 begin = 1;
    optional int64 end = 2;
  }
  optional Segment segment = 3;
  repeated float float_data = 4 [packed = true];
  repeated int32 int32_data = 5 [packed = true];
  repeated bytes string_data = 6;
  repeated int64 int64_data = 7 [packed = true];
  optional string name = 8; // namespace Value
  optional string doc_string = 12;
  optional bytes raw_data = 9;
  repeated StringStringEntryProto external_data = 13;

  enum DataLocation {
    DEFAULT = 0;
    EXTERNAL = 1;
  }

  optional DataLocation data_location = 14;
  repeated double double_data = 10 [packed = true];
  repeated uint64 uint64_data = 11 [packed = true];
}

而input数组是ValueInfoProto类型,只记录了张量的名字:

message ValueInfoProto {
  optional string name = 1;     // namespace Value
  optional TypeProto type = 2;
  optional string doc_string = 3;
}

2. from_onnx流程

relay.frontend.from_onnx就是读入onnx模型,按照前述onnx模型数据结构,解析模型的initializer、input、nodes和output,将算子转换为tvm relay ir算子和表达式,最终得到整个模型的tvm relay IRModule。from_onnx定义在python/tvm/relay/frontend/onnx.py:

def from_onnx(
    model, shape=None, dtype="float32", opset=None, freeze_params=False, convert_config=None
):
    global ONNX_DEFAULT_CONFIGS
    if convert_config is not None:
        ONNX_DEFAULT_CONFIGS.update(convert_config)

    try:
        import onnx

        if hasattr(onnx.checker, "check_model"):
            # try use onnx's own model checker before converting any model
            try:
                onnx.checker.check_model(model)
            except Exception as e:  # pylint: disable=c-extension-no-member, broad-except
                # the checker is a bit violent about errors, so simply print warnings here
                warnings.warn(str(e))
    except ImportError:
        pass
    g = GraphProto(shape, dtype, freeze_params)
    graph = model.graph

    try:
        opset_in_model = model.opset_import[0].version if model.opset_import else 1
    except AttributeError:
        opset_in_model = 1

    if opset is None:
        opset = opset_in_model
    elif opset < opset_in_model:
        warnings.warn(
            ""
            f"You are overwritting original opset ver = {opset_in_model} by lower ver = {opset}. "
            f"That might cause model conversion errors."
        )

    # Use the graph proto as a scope so that ops can access other nodes if needed.
    with g:
        mod, params = g.from_onnx(graph, opset)
    return mod, params

我们可以只用关注下面几行代码

...

g = GraphProto(shape, dtype, freeze_params)
graph = model.graph

...

with g:
        mod, params = g.from_onnx(graph, opset)
    return mod, params

这里生成一个TVM的GraphProto实例g, 然后将传入的onnx模型graph传入GraphProto的from_onnx方法。

GraphProto.from_onnx的参数有onnx模型的TVM GraphProto实例、版本信息、和转换结果返回方式配置:如果设置为true,则只打印onnx模型转为tvm后的表示;默认为false,将返回onnx模型的TVM中间表示数据mod(tvm.IRModule类型)和参数params。

    def from_onnx(self, graph, opset, get_output_expr=False):
        """Construct Relay expression from ONNX graph.

        Onnx graph is a python protobuf object.
        The companion parameters will be handled automatically.
        However, the input names from onnx graph is vague, mixing inputs and
        network weights/bias such as "1", "2"...
        For convenience, we rename the `real` input names to "input_0",
        "input_1"... And renaming parameters to "param_0", "param_1"...

        Parameters
        ----------
        graph : onnx protobuf object
            The loaded onnx graph

        opset : opset version

        get_output_expr: bool
            If set to true, this conversion will return each output expression rather
            than a packaged module. This can be useful when converting subgraphs to
            relay.

        Returns
        -------
        mod : tvm.IRModule
            The returned relay module

        params : dict
            A dict of name: tvm.nd.array pairs, used as pretrained weights
        """
        self.opset = opset
        self._parse_graph_initializers(graph)
        self._parse_graph_input(graph)
        self._check_user_inputs_in_outermost_graph_scope()
        self._check_for_unsupported_ops(graph)
        self._construct_nodes(graph)

        # now return the outputs
        outputs = [self._nodes[self._parse_value_proto(i)] for i in graph.output]
        outputs = outputs[0] if len(outputs) == 1 else _expr.Tuple(outputs)
        # If requested, directly return the converted expressions.
        if get_output_expr:
            return outputs
        ## Maintain the order of inputs and parameters from the ONNX graph, but only include
        ## those parameters that are needed to execute the relay graph
        free_vars = analysis.free_vars(outputs)
        nodes = {v: k for k, v in self._nodes.items()}
        free_vars = [nodes[var] for var in free_vars]
        for i_name in self._params:
            if i_name in free_vars and i_name not in self._inputs:
                self._inputs[i_name] = self._nodes[i_name]
        # Create a function from our output expression and all input variables.
        func = _function.Function([v for k, v in self._inputs.items()], outputs)
        return IRModule.from_expr(func), self._params

2.1 解析onnx模型权重

from_onnx中,首先调用_parse_graph_initializers从onnx模型的initializer数据段中解析转换网络权重数据:

    def _parse_graph_initializers(self, graph):
        """Parse network inputs to relay, aka parameters."""
        # onnx的initializer存放了模型的每个节点的权重参数
        for init_tensor in graph.initializer:
            # 这里的init_tensor.name权重张量的名字.例如网络中某个卷积权重参数W名字为186,偏置参数B名
		    # 字为188, graph.initializer里面就会有两个对应的节点,name分别为186和188

            if not init_tensor.name.strip():
                raise ValueError("Tensor's name is required.")
            # 将权重矩阵转换为tvm的nD-array
            array = self._parse_array(init_tensor)
            if self._freeze_params:
                # 如果设置了feeze_params参数,则看做常量节点(针对不定参数模型的优化?)
                self._nodes[init_tensor.name] = _expr.const(array)
            else:
                # 将解析的参数记录到参数表中
                self._params[init_tensor.name] = array
                # 在节点表中新增一个节点记录该参数
                self._nodes[init_tensor.name] = new_var(
                    init_tensor.name,
                    shape=self._params[init_tensor.name].shape,
                    dtype=self._params[init_tensor.name].dtype,
                )

如前onnx模型结构介绍所述,initializer中的张量必须是有名字的,所以这里对每个元素都判断是否有名字,如果名字字段为空,则认为网络不合法。

_parse_array是将onnx的tensor转换为了tvm的numpy数组。

2.2 解析onnx网络graph的input字段

如前所述,GraphProto中的input数组不仅包含模型的输入,还包含了模型中各节点的权重。也就是将网络的输入节点和各网络节点的权重参数都当作输入

def _parse_graph_input(self, graph):
        for i in graph.input:
            # from onnx v0.2, GraphProto.input has type ValueInfoProto,
            #  and the name is 'i.name'
            # 获取参数的name, shape, type等信息
            i_name, i_shape, d_type, i_shape_name = get_info(i)
            if i_name in self._params:
                # i is a param instead of input
                # 如果是节点参数,则在前面graph.initializer的处理中已经在参数表中添加了对应的节点
                self._num_param += 1
                self._nodes[i_name] = new_var(
                    i_name, shape=self._params[i_name].shape, dtype=self._params[i_name].dtype
                )
            elif i_name in self._nodes:
                continue
            else:
                # 如果是模型的输入
                self._num_input += 1
                self._input_names.append(i_name)
                # self._shape是用户在调用from_onnx的时候传入的模型输入shape参数
                if i_name in self._shape:
                    i_shape = self._shape[i_name]
                else:
                    # 模型的输入shape有不定项
                    if "?" in str(i_shape):
                        warning_msg = (
                            "Input %s has unknown dimension shapes: %s. "
                            "Specifying static values may improve performance"
                            % (i_name, str(i_shape_name))
                        )
                        warnings.warn(warning_msg)
                if isinstance(self._dtype, dict):
                    dtype = self._dtype[i_name] if i_name in self._dtype else d_type
                else:
                    dtype = d_type
                # 在nodes表中加入输入节点
                self._nodes[i_name] = new_var(i_name, shape=i_shape, dtype=dtype)
            #记录模型的输入node
            self._inputs[i_name] = self._nodes[i_name]

get_info函数是解析onnx的ValueInfoProto类型数据,获取数据的name、shape、类型等信息。

2.3 检查是否有不支持的算子

调用_check_for_unsupported_ops检查当前的onnx网络中所有算子是不是都能转换为tvm relay ir 

    def _check_for_unsupported_ops(self, graph):
        # 获取onnx算子到tvm relay ir的转换映射表
        convert_map = _get_convert_map(self.opset)
        unsupported_ops = set()
        for node in graph.node:
            op_name = node.op_type
            if (
                op_name not in convert_map
                and op_name != "Constant"
                and op_name not in _identity_list
            ):          
			    # 如果算子不在映射表中,也不在_identity_list表中,则认为当前算子是TVM不支持的
                unsupported_ops.add(op_name)
        # 如果有不支持的算子,则转换失败
        if unsupported_ops:
            msg = "The following operators are not supported for frontend ONNX: "
            msg += ", ".join(unsupported_ops)
            raise tvm.error.OpNotImplemented(msg)

 _get_convert_map根据(调用from_onnx时传入的)版本号,返回一个表,每个表单元的索引是onnx算子名称,key值是该算子转换为tvm relay ir形式的接口。如果某个onnx算子没有对应的转换接口,就认为tvm当前不支持该算子。具体转换细节可以参考onnx到tvm relay ir的转换。

2.4 创建网络的DAG

然后调用_construct_nodes函数,解析onnx网络的各个节点以及节点连接关系,在tvm中创建网络的DAG(有向无环图) 

    def _construct_nodes(self, graph):
        """Nodes are stored as directed acyclic graph."""
        # 遍历onnx模型的每个算子节点
        for node in graph.node:
            #算子名称,不是节点名称
            op_name = node.op_type
            #解析节点属性
            attr = self._parse_attr(node.attribute)
            # Create and populate input list.
            inputs = onnx_input()
            # 获取节点的所有输入
            for i in node.input:
                if i != "":
                    inputs.append(self._nodes[self._renames.get(i, i)])
                else:
                    inputs.append(None)
            i_name = self._parse_value_proto(node)
            # 获取节点的输出
            node_output = self._fix_outputs(op_name, node.output)
            attr["tvm_custom"] = {}
            attr["tvm_custom"]["name"] = i_name
            attr["tvm_custom"]["num_outputs"] = len(node_output)
            # 将当前onnx节点转换为tvm relay ir
            op = self._convert_operator(op_name, inputs, attr, self.opset)
            if not isinstance(op, _expr.TupleWrapper):
                outputs_num = 1
            else:
                outputs_num = len(op)

            if outputs_num == 1:
                op = fold_constant(op)
            else:
                op = _expr.TupleWrapper(fold_constant(op.astuple()), len(op))

            if outputs_num > 1:
                # ONNX supports optional outputs for some nodes.
                # This block searches for missing outputs in the ONNX graph
                # and removes any unneeded ops
                #下面这段代码的意思是:onnx支持一个节点有多个输出,但是有些输出并不实际使用.在转换为tvm relay ir的时候,我们将这些输出剔除掉.具体做法如下:
                # 获取节点的有效输出.如果某个输出没有名字,那么认为这个输出没有被(自己或者其他节点)使用,是无效输出
                valid_outputs = [False] * outputs_num
                for i, output in enumerate(node_output):
                    if output != "":
                        valid_outputs[i] = True
                # If we have outputs ONNX isn't expecting, we need to drop them
                # 如果节点有输出是无效
                if not all(valid_outputs):
                    # 这里op为onnx转换后的tvm表示
                    tup = op.astuple()
                    # TupleWrapper can also wrap ops with TupleType outputs
                    # 从tvm表达式中将有效输出对应的部分挑出来,组成当前节点的实际输出
                    if isinstance(tup, _expr.Tuple):
                        # For tuples, we extract the fields instead of using GetTupleItem
                        outputs = [tup.fields[i] for i, valid in enumerate(valid_outputs) if valid]
                    else:
                        # For call nodes, we need to GetTupleItem
                        outputs = [op[i] for i, valid in enumerate(valid_outputs) if valid]
                    # Create the new op with valid outputs
                    if len(outputs) == 1:
                        op = outputs[0]
                    # 如果有多个输出并且有无效输出被剔除,需要重新打包当前节点的tvm relay ir
                    elif len(outputs) != outputs_num:
                        op = _expr.TupleWrapper(_expr.Tuple(outputs), len(outputs))
                    # Drop invalid outputs for the onnx node
                    # 更新onnx 节点的输出表, string类型
                    outputs_num = len(outputs)                    
                    node_output = [output for output in node_output if output != ""]
            assert (
                len(node_output) == outputs_num
            ), "Number of output mismatch {} vs {} in {}.".format(
                len(node_output), outputs_num, op_name
            )
            
            # 将输出加入节点表, 节点的值为节点的tvm表示
            if outputs_num == 1:
                self._nodes[node_output[0]] = op
            else:
                for k, i in zip(list(node_output), range(len(node_output))):
                    self._nodes[k] = op[i]

代码中调用_convert_operator将onnx算子转换为tvm relay ir。(前面_get_convert_map得到的是各个onnx算子的转换接口,接口执行后得到的才是tvm relay ir)。详细的转换流程见onnx到tvm relay ir的转换。

2.5 处理onnx模型输出

        # now return the outputs
        outputs = [self._nodes[self._parse_value_proto(i)] for i in graph.output]
        outputs = outputs[0] if len(outputs) == 1 else _expr.Tuple(outputs)
        # If requested, directly return the converted expressions.
        if get_output_expr:
            return outputs

graph.output是onnx模型的输出,是一个onnx ValueInfoProto数组。_parse_vale_proto(i)返回输出i的名字。而当前self._nodes中是每个节点的输出的tvm relay ir,里面当然也就有网络最后一个节点的。所以第一行的outputs返回了网络所有输出节点的tvm relay ir。因为某个节点的输入是上一个节点的输出,而这上一个节点的输出也是tvm relay ir,被带入当前节点。例如某一个节点的tvm relay ir:

################################################
onnx op node  Convolution110
output:  ['Convolution110_Output_0']
convert to tvm op:  <class 'tvm.relay.expr.Call'>
free_var %Input3: Tensor[(1, 1, 28, 28), float32];
%0 = nn.pad(%Input3, 0f, pad_width=[[0, 0], [0, 0], [2, 2], [2, 2]]);
free_var %Parameter5: Tensor[(8, 1, 5, 5), float32];
%1 = nn.conv2d(%0, %Parameter5, padding=[0, 0, 0, 0], channels=8, kernel_size=[5, 5]);
free_var %Parameter6: Tensor[(8, 1, 1), float32];
%2 = add(%1, %Parameter6);
%3 = nn.relu(%2);
%4 = nn.max_pool2d(%3, pool_size=[2, 2], strides=[2, 2], padding=[0, 0, 0, 0]);
%5 = nn.pad(%4, 0f, pad_width=[[0, 0], [0, 0], [2, 2], [2, 2]]);
free_var %Parameter87: Tensor[(16, 8, 5, 5), float32];
nn.conv2d(%5, %Parameter87, padding=[0, 0, 0, 0], channels=16, kernel_size=[5, 5])
#####################################################

数据流向上的下一个节点:

################################################
onnx op node  Plus112
output:  ['Plus112_Output_0']
convert to tvm op:  <class 'tvm.relay.expr.Call'>
free_var %Input3: Tensor[(1, 1, 28, 28), float32];
%0 = nn.pad(%Input3, 0f, pad_width=[[0, 0], [0, 0], [2, 2], [2, 2]]);
free_var %Parameter5: Tensor[(8, 1, 5, 5), float32];
%1 = nn.conv2d(%0, %Parameter5, padding=[0, 0, 0, 0], channels=8, kernel_size=[5, 5]);
free_var %Parameter6: Tensor[(8, 1, 1), float32];
%2 = add(%1, %Parameter6);
%3 = nn.relu(%2);
%4 = nn.max_pool2d(%3, pool_size=[2, 2], strides=[2, 2], padding=[0, 0, 0, 0]);
%5 = nn.pad(%4, 0f, pad_width=[[0, 0], [0, 0], [2, 2], [2, 2]]);
free_var %Parameter87: Tensor[(16, 8, 5, 5), float32];
%6 = nn.conv2d(%5, %Parameter87, padding=[0, 0, 0, 0], channels=16, kernel_size=[5, 5]);
free_var %Parameter88: Tensor[(16, 1, 1), float32];
add(%6, %Parameter88)
################################################

其实这里下游节点只是最后一行那个add算子,其他的都是上游,从网络输入结点开始的各个节点的累加。 

这样按照数据流向叠加到最后整个模型的输出,得到的就是整个模型的tvm relay ir。这里代码第一行的outputs得到的就是整个网络的tvm relay ir。第二行将其打包为一个tuple

如果我们在调用GraphProto.from_onnx的时候传入的get_output_expr参数为true,那么模型转换就到此为止了,返回的是模型的tvm relay ir。但是我们在编译运行模型的脚本中调用的relay.frontend.from_onnx接口(这个接口里面调用了GraphProto.from_onnx)没有这个参数,所以这里不会返回。

2.6 打包模型转换输出

        ## Maintain the order of inputs and parameters from the ONNX graph, but only include
        ## those parameters that are needed to execute the relay graph
        free_vars = analysis.free_vars(outputs)
        nodes = {v: k for k, v in self._nodes.items()}
        free_vars = [nodes[var] for var in free_vars]
        for i_name in self._params:
            if i_name in free_vars and i_name not in self._inputs:
                self._inputs[i_name] = self._nodes[i_name]
        # Create a function from our output expression and all input variables.
        func = _function.Function([v for k, v in self._inputs.items()], outputs)
        return IRModule.from_expr(func), self._params

 这里的流程:

1. 调用python/tvm/relay/analysis/analysis.py的free_vars接口,采用post DFS算法,从网络的输出开始遍历网络的tvm relay ir,找到free变量(是什么东西?按照官方文档,应该是权重之类的);

2. 然后从节点表中获取这些fee变量对应的节点;

3. 将这些节点加入网络的输入表_inputs中;

4. 调用_function.Function,传入网络的输入,参数和转换后的网络表示,得到_func;

5. 最后返回网络的tvm表达和权重参数

返回的网络tvm表达是什么样子呢?我们可以在https://blog.csdn.net/zx_ros/article/details/125894033的模型编译运行脚本中直接打印返回的mod看看:

......
shape_dict = {input_name: data.shape}
# 导入onnx模型
mod, params = relay.frontend.from_onnx(onnx_model, shape_dict)
print(mod)

...

输出

dl@dl:~/tvm_learning$ python3 mnist_onnx.py 
/home/dl/tvm/python/tvm/driver/build_module.py:267: UserWarning: target_host parameter is going to be deprecated. Please pass in tvm.target.Target(target, host=target_host) instead.
  warnings.warn(
def @main(%Input3: Tensor[(1, 1, 28, 28), float32] /* ty=Tensor[(1, 1, 28, 28), float32] */) -> Tensor[(1, 10), float32] {
  %0 = nn.pad(%Input3, 0f /* ty=float32 */, pad_width=[[0i64, 0i64], [0i64, 0i64], [2i64, 2i64], [2i64, 2i64]]) /* ty=Tensor[(1, 1, 32, 32), float32] */;
  %1 = nn.conv2d(%0, meta[relay.Constant][0] /* ty=Tensor[(8, 1, 5, 5), float32] */, padding=[0, 0, 0, 0], channels=8, kernel_size=[5, 5]) /* ty=Tensor[(1, 8, 28, 28), float32] */;
  %2 = add(%1, meta[relay.Constant][1] /* ty=Tensor[(8, 1, 1), float32] */) /* ty=Tensor[(1, 8, 28, 28), float32] */;
  %3 = nn.relu(%2) /* ty=Tensor[(1, 8, 28, 28), float32] */;
  %4 = nn.max_pool2d(%3, pool_size=[2, 2], strides=[2, 2], padding=[0, 0, 0, 0]) /* ty=Tensor[(1, 8, 14, 14), float32] */;
  %5 = nn.pad(%4, 0f /* ty=float32 */, pad_width=[[0i64, 0i64], [0i64, 0i64], [2i64, 2i64], [2i64, 2i64]]) /* ty=Tensor[(1, 8, 18, 18), float32] */;
  %6 = nn.conv2d(%5, meta[relay.Constant][2] /* ty=Tensor[(16, 8, 5, 5), float32] */, padding=[0, 0, 0, 0], channels=16, kernel_size=[5, 5]) /* ty=Tensor[(1, 16, 14, 14), float32] */;
  %7 = add(%6, meta[relay.Constant][3] /* ty=Tensor[(16, 1, 1), float32] */) /* ty=Tensor[(1, 16, 14, 14), float32] */;
  %8 = nn.relu(%7) /* ty=Tensor[(1, 16, 14, 14), float32] */;
  %9 = nn.max_pool2d(%8, pool_size=[3, 3], strides=[3, 3], padding=[0, 0, 0, 0]) /* ty=Tensor[(1, 16, 4, 4), float32] */;
  %10 = reshape(%9, newshape=[1, 256]) /* ty=Tensor[(1, 256), float32] */;
  %11 = nn.dense(%10, meta[relay.Constant][4] /* ty=Tensor[(10, 256), float32] */, units=None, out_dtype="float32") /* ty=Tensor[(1, 10), float32] */;
  add(%11, meta[relay.Constant][5] /* ty=Tensor[(1, 10), float32] */) /* ty=Tensor[(1, 10), float32] */
}

如果进一步探索下_function.Function都干了些什么,可以看到最后是进入到C++代码中,执行了下面的C++ lamabda函数,返回一个C++的Function句柄:

TVM_REGISTER_GLOBAL("relay.ir.Function")
    .set_body_typed([](tvm::Array<Var> params, Expr body, Type ret_type,
                       tvm::Array<TypeVar> ty_params, tvm::DictAttrs attrs) {
      return Function(params, body, ret_type, ty_params, attrs);
    });

同样的,IRModule.from_expr最后调用的是C++代码中TVM_REGISTER_GLOBAL("ir.Module_FromExpr")注册的接口,接口函数执行IRModule::FromExpr,FromExpr调用IRModule::FromExprInContext,生成一个C++端的IRModule实例

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

【TVM源码学习笔记】2 模型导入from_onnx 的相关文章

  • 如何将JavaWeb项目部署到Linux服务器

    将JavaWeb项目部署到服务器需要先在服务器安装tomcat 数据库 Java环境 没有安装的同学先装好以上三件套 当装好这三样后就可以开始部署JavaWeb项目了 其实很简单 将项目打包成war文件后上传到tomcat下的webapps
  • 芯片学生党必会的行业英文术语

    转载至芯职业公众号 芯片领域有不少英文缩写术语 对学生党与初学者颇为费解 严重的还会给人 劝退 感 因此 在这个系列 我们将介绍一些常用的英文术语 旨在让大家了解这些英文的同时对芯片产业的全貌有一个大体的认识 并不作过分深究 Wafer D
  • Android 用surfaceview模拟帧动画的效果,解决帧动画的OOM问题

    最近做的项目 客户临时要求改版 我真的是最烦这个 要求跟换主页面的背景 换上新的背景图 要求是动态的 效果 我随便拿的五个图片做的gif 方案 帧动画方案 缺点 1 好像只能imageview才能播放帧动画 2 容易OOM 播三四张还行 播
  • 9款超级实用 VSCode 插件,让 Python 编程轻松愉悦

    1 Python preview Python Preview是一个适用于VSCode的Python代码预览插件 可以将Python代码转换为漂亮的HTML页面 并在浏览器中进行预览 通过该插件 程序员可以在VSCode中方便地预览Pyth
  • 点云Las文件读写c++库 Lasib_msvc2015

    点云Las文件读写c 库 Lasib msvc2015 前言 去官网下载laslib源码 发现编译错误 需要以下的几个依赖库 1 在进行编译之前我们首先需要编译Boost GDAL TIFF LASZIP和GeoTIFF的编译 大家可以参考

随机推荐