ncnn源码分析-004-代码流程总结

0.调用实例

先看一个调用实例,顺着调用流程探寻ncnn内部具体实现细节。

#include "net.h"

int main(int argc, char **argv)
{
    ncnn::Mat in;
    ncnn::Mat out;

    ncnn::Net net;
    net.load_param("model.param");
    net.load_model("model.bin");

    ncnn::Extractor ex = net.create_extractor();
    ex.set_light_mode(true);
    ex.set_num_threads(4);

    ex.input("data", in);
    ex.extract("output", out);

    return 0;
}

1.blob结构

class Blob
{
public:
    std::string name; //blob名字
    int producer; //指明该blob来自哪个层的输出,层索引
    std::vector<int> consumers; //指明该blob作为哪个层的输入,层索引
}

在blob的构造函数中初始化producer=-1

2.layer

class Layer
{
public:
    int typeindex; //类型ID
    std::string type; //类型名字
    std::string name; //层的名字
    std::vector<int> bottoms; //当前层所有输入blob的索引
    std::vector<int> tops; //当前层所有输出blob索引

    int load_param(const ParamDict &pd);
    int load_model(const ModelBin &mb);
    int forward(const std::vector<Mat> &bottom_blobs, std::vector<Mat> &top_blobs, const Option &opt = get_default_option());
}
  • layer进行前向传播时,根据bottoms索引值找到bottom数据,作为forward的输入,计算结果存入tops对应的blob里,完成一层的inferecnce
  • load_param和load_model有三种定义
    • 第一个是在net里定义为int load_param(FILE *fp);,主要功能是从param文件中读取数据,网络层数,网络blob个数,每一个输入输出blob的类型、名字等信息。涉及到网络结构的参数(不是训练参数),例如滤波器个数、padding、stride等信息由ParamDict里的load_param负责读取。
    • 第二个是在ParamDict里,定义为int load_param(FILE *fp);fp位置接net里的位置,该函数将参数读取到ParamDict的一个类实例pd里,以pair对的形式存储,不考虑具体参数含义,只需按照key,value存储你即可。
    • 第三个是在layer里,定义为int load_param(const ParamDict &pd);这个load_param负责从pair对里根据不同层对key的定义解析成和每一个层对应的参数,参数的不同决定了相同类型层的差异性,比如同样是卷积层,但是滤波器个数不同。
    • load_model和load_param类似,至此完成了整个网络的解析工作。

以上内容对应于我们平时使用ncnn的以下代码形式:

ncnn::Net net;
net.load_param("model.param");
net.load_model("model.bin");

2.net解析

class Net
{
public:
    int usewinograd_convolution;
    int use_sgemm_convolution;
    int use_int8_inference;
    int use_vulkan_compute;

    int load_param(FILE *fp);
    int load_model(FILE *fp);
    Extractor create_extractor();
protected:
    std::vector<Blob> blobs;//网络的所有blob
    std::vector<Layer*> layers;//网络的所有层指针

    int forward_layer(int layer_index, std::vector<Mat> &blob_mats, Options &opt);
    int find_blob_index_by_name(const char* name);
    int find_layer_index_by_name(const char *name);
}

class Extractor
{
public:
    int Extractor::input(const char *blob_name, const Mat &in);
    int Extractor::input(int blob_index, VkMat &feat, VkCompute &cmd);
    int Extractor::extract(const char *blob_name, Mat &feat);
    int Extractor::extract(blob_index, const Mat &feat);//次函数直接forward_layer()
protected:
    friend Extractor Net::create_extractor() const;
    Extractor(const Net *net, int blob_count);

private:
    const Net *net;
    std::vector<Mat> blob_mats;
    Option opt;
}
  • Net类里的成员变量包含了另一个类create_extractor方法,该方法实际上就是调用Extractor类的构造函数,返回一个Extractor实例。
Extractor Net::create_extractor() const
{
    return Extractor(this, blobs.size());
}

调用Extractor::input(const char *blob_name, const Mat &in)设置输入数据,这里比较简单,通过输入blob名字找到对应的索引,然后根据索引取到真实的blob数据。

  • 整个网络的核心是Extractor::extract(const char *blob_name, Mat &feat)

    • 该函数调用另一个重载函数Extractor::extract(blob_index, const Mat &feat)两个输入参数分别是要获取数据blob索引和存放数据的输出变量,通过blob_index在blobs(net的类成员变量,用于存放整个网络的所有blob)找到对应的blob
    • 调用blob的类成员变量producer来找到该blob是哪个layer的输出,也就是layer_index
    • 接下来调用forward_layer(layer_index, blob_mats, opt)完成整个网络的前向传播,逐层前传使用递归完成,blob_mats里存放真正的blob数据,是net的私有成员变量std::vector<Mat> blob_mats,Mat是自定义类型
    • 完成forward_layer后,feat = blob_mats[blob_index],feat是调用例子中out的引用,将blob数据存放在feat变量中,整个流程结束

原文地址:https://www.cnblogs.com/ganchunsheng/p/11686703.html

时间: 2024-10-04 14:13:40

ncnn源码分析-004-代码流程总结的相关文章

HBase1.0.0源码分析之请求处理流程分析以Put操作为例(二)

HBase1.0.0源码分析之请求处理流程分析以Put操作为例(二) 1.通过mutate(put)操作,将单个put操作添加到缓冲操作中,这些缓冲操作其实就是Put的父类的一个List的集合.如下: private List<Row> writeAsyncBuffer = new LinkedList<>(); writeAsyncBuffer.add(m); 当writeAsyncBuffer满了之后或者是人为的调用backgroundFlushCommits操作促使缓冲池中的

openVswitch(OVS)源码分析之工作流程(哈希桶结构体的解释)

这篇blog是专门解决前篇openVswitch(OVS)源码分析之工作流程(哈希桶结构体的疑惑)中提到的哈希桶结构flex_array结构体成员变量含义的问题. 引用下前篇blog中分析讨论得到的flex_array结构体成员变量的含义结论: struct { int element_size; // 这是flex_array_part结构体存放的哈希头指针的大小 int total_nr_elements; // 这是全部flex_array_part结构体中的哈希头指针的总个数 int e

Spark SQL源码分析之核心流程

自从去年Spark Submit 2013 Michael Armbrust分享了他的Catalyst,到至今1年多了,Spark SQL的贡献者从几人到了几十人,而且发展速度异常迅猛,究其原因,个人认为有以下2点: 1.整合:将SQL类型的查询语言整合到 Spark 的核心RDD概念里.这样可以应用于多种任务,流处理,批处理,包括机器学习里都可以引入Sql. 2.效率:因为Shark受到hive的编程模型限制,无法再继续优化来适应Spark模型里. 前一段时间测试过Shark,并且对Spark

leveldb源码分析--插入删除流程

由于网络上对leveldb的分析文章都比较丰富,一些基础概念和模型都介绍得比较多,所以本人就不再对这些概念以专门的篇幅进行介绍,本文主要以代码流程注释的方式. 首先我们从db的插入和删除开始以对整个体系有一个感性的认识,首先看插入: Status DB::Put(const WriteOptions& opt, const Slice& key, const Slice& value) { WriteBatch batch; //leveldb中不管单个插入还是多个插入都是以Wri

Okhttp源码分析--基本使用流程分析

Okhttp源码分析--基本使用流程分析 一. 使用 同步请求 OkHttpClient okHttpClient=new OkHttpClient(); Request request=new Request.Builder() .get() .url("www.baidu.com") .build(); Call call =okHttpClient.newCall(request).execute(); 异步请求 OkHttpClient okHttpClient=new OkH

MyBatis源码分析-MyBatis初始化流程

MyBatis 是支持定制化 SQL.存储过程以及高级映射的优秀的持久层框架.MyBatis 避免了几乎所有的 JDBC 代码和手动设置参数以及获取结果集.MyBatis 可以对配置和原生Map使用简单的 XML 或注解,将接口和 Java 的 POJOs(Plain Old Java Objects,普通的 Java对象)映射成数据库中的记录.如何新建MyBatis源码工程请点击MyBatis源码分析-IDEA新建MyBatis源码工程. MyBatis初始化的过程也就是创建Configura

HBase1.0.0源码分析之请求处理流程分析以Put操作为例(一)

如下面的代码所示,是HBase Put操作的简单代码实例,关于代码中的Connection connection = ConnectionFactory.createConnection(conf),已近在前一篇博 HBase1.0.0源码分析之Client启动连接流程,中介绍了链接的相关流程以及所启动的服务信息. TableName tn = TableName.valueOf("test010"); try (Connection connection = ConnectionFa

第一篇:Spark SQL源码分析之核心流程

/** Spark SQL源码分析系列文章*/ 自从去年Spark Submit 2013 Michael Armbrust分享了他的Catalyst,到至今1年多了,Spark SQL的贡献者从几人到了几十人,而且发展速度异常迅猛,究其原因,个人认为有以下2点: 1.整合:将SQL类型的查询语言整合到 Spark 的核心RDD概念里.这样可以应用于多种任务,流处理,批处理,包括机器学习里都可以引入Sql.    2.效率:因为Shark受到hive的编程模型限制,无法再继续优化来适应Spark

Monkey源码分析之运行流程

在<MonkeyRunner源码分析之与Android设备通讯方式>中,我们谈及到MonkeyRunner控制目标android设备有多种方法,其中之一就是在目标机器启动一个monkey服务来监听指定的一个端口,然后monkeyrunner再连接上这个端口来发送命令,驱动monkey去完成相应的工作. 当时我们只分析了monkeyrunner这个客户端的代码是怎么实现这一点的,但没有谈monkey那边是如何接受命令,接受到命令又是如何处理的. 所以自己打开源码看了一个晚上,大概有了概念.但今天

SpringMVC源码分析-400异常处理流程及解决方法

本文设计SpringMVC异常处理体系源码分析,SpringMVC异常处理相关类的设计模式,实际工作中异常处理的实践. 问题场景 假设我们的SpringMVC应用中有如下控制器: 代码示例-1 @RestController("/order") public class OrderController{ @RequestMapping("/detail") public Object orderDetail(int orderId){ // ... } } 这个控制