【源码阅读】opencv中opencl版本的dft函数的实现细节

1.函数声明

opencv-3.4.3\modules\core\include\opencv2\core.hpp:2157

CV_EXPORTS_W void dft(InputArray src, OutputArray dst, int flags = 0, int nonzeroRows = 0);

2.函数实现

opencv-3.4.3\modules\core\src\dxt.cpp:3315

void cv::dft( InputArray _src0, OutputArray _dst, int flags, int nonzero_rows )
{
    CV_INSTRUMENT_REGION()

#ifdef HAVE_CLAMDFFT
    CV_OCL_RUN(ocl::haveAmdFft() && ocl::Device::getDefault().type() != ocl::Device::TYPE_CPU &&
            _dst.isUMat() && _src0.dims() <= 2 && nonzero_rows == 0,
               ocl_dft_amdfft(_src0, _dst, flags))
#endif

#ifdef HAVE_OPENCL
    CV_OCL_RUN(_dst.isUMat() && _src0.dims() <= 2,
               ocl_dft(_src0, _dst, flags, nonzero_rows))
#endif

    Mat src0 = _src0.getMat(), src = src0;
    bool inv = (flags & DFT_INVERSE) != 0;
    int type = src.type();
    int depth = src.depth();

    CV_Assert( type == CV_32FC1 || type == CV_32FC2 || type == CV_64FC1 || type == CV_64FC2 );

    // Fail if DFT_COMPLEX_INPUT is specified, but src is not 2 channels.
    CV_Assert( !((flags & DFT_COMPLEX_INPUT) && src.channels() != 2) );

    if( !inv && src.channels() == 1 && (flags & DFT_COMPLEX_OUTPUT) )
        _dst.create( src.size(), CV_MAKETYPE(depth, 2) );
    else if( inv && src.channels() == 2 && (flags & DFT_REAL_OUTPUT) )
        _dst.create( src.size(), depth );
    else
        _dst.create( src.size(), type );

    Mat dst = _dst.getMat();

    int f = 0;
    if (src.isContinuous() && dst.isContinuous())
        f |= CV_HAL_DFT_IS_CONTINUOUS;
    if (inv)
        f |= CV_HAL_DFT_INVERSE;
    if (flags & DFT_ROWS)
        f |= CV_HAL_DFT_ROWS;
    if (flags & DFT_SCALE)
        f |= CV_HAL_DFT_SCALE;
    if (src.data == dst.data)
        f |= CV_HAL_DFT_IS_INPLACE;
    Ptr<hal::DFT2D> c = hal::DFT2D::create(src.cols, src.rows, depth, src.channels(), dst.channels(), f, nonzero_rows);
    c->apply(src.data, src.step, dst.data, dst.step);
}

3. opencl的调用

#ifdef HAVE_OPENCL
    CV_OCL_RUN(_dst.isUMat() && _src0.dims() <= 2,
               ocl_dft(_src0, _dst, flags, nonzero_rows))
#endif

ocl的函数实现:
opencv-3.4.3\modules\core\src\dxt.cpp:2161

static bool ocl_dft(InputArray _src, OutputArray _dst, int flags, int nonzero_rows)
{
    int type = _src.type(), cn = CV_MAT_CN(type), depth = CV_MAT_DEPTH(type);
    Size ssize = _src.size();
    bool doubleSupport = ocl::Device::getDefault().doubleFPConfig() > 0;

    if (!(cn == 1 || cn == 2)
        || !(depth == CV_32F || (depth == CV_64F && doubleSupport))
        || ((flags & DFT_REAL_OUTPUT) && (flags & DFT_COMPLEX_OUTPUT)))
        return false;

    // if is not a multiplication of prime numbers { 2, 3, 5 }
    if (ssize.area() != getOptimalDFTSize(ssize.area()))
        return false;

    UMat src = _src.getUMat();
    bool inv = (flags & DFT_INVERSE) != 0 ? 1 : 0;

    if( nonzero_rows <= 0 || nonzero_rows > _src.rows() )
        nonzero_rows = _src.rows();
    bool is1d = (flags & DFT_ROWS) != 0 || nonzero_rows == 1;

    FftType fftType = determineFFTType(cn == 1, cn == 2,
        (flags & DFT_REAL_OUTPUT) != 0, (flags & DFT_COMPLEX_OUTPUT) != 0, inv);

    UMat output;
    if (fftType == C2C || fftType == R2C)
    {
        // complex output
        _dst.create(src.size(), CV_MAKETYPE(depth, 2));
        output = _dst.getUMat();
    }
    else
    {
        // real output
        if (is1d)
        {
            _dst.create(src.size(), CV_MAKETYPE(depth, 1));
            output = _dst.getUMat();
        }
        else
        {
            _dst.create(src.size(), CV_MAKETYPE(depth, 1));
            output.create(src.size(), CV_MAKETYPE(depth, 2));
        }
    }

    bool result = false;
    if (!inv)
    {
        int nonzero_cols = fftType == R2R ? output.cols/2 + 1 : output.cols;
        result = ocl_dft_rows(src, output, nonzero_rows, flags, fftType);
        if (!is1d)
            result = result && ocl_dft_cols(output, _dst, nonzero_cols, flags, fftType);
    }
    else
    {
        if (fftType == C2C)
        {
            // complex output
            result = ocl_dft_rows(src, output, nonzero_rows, flags, fftType);
            if (!is1d)
                result = result && ocl_dft_cols(output, output, output.cols, flags, fftType);
        }
        else
        {
            if (is1d)
            {
                result = ocl_dft_rows(src, output, nonzero_rows, flags, fftType);
            }
            else
            {
                int nonzero_cols = src.cols/2 + 1;
                result = ocl_dft_cols(src, output, nonzero_cols, flags, fftType);
                result = result && ocl_dft_rows(output, _dst, nonzero_rows, flags, fftType);
            }
        }
    }
    return result;
}

4.ocl_dft()里面的row/col的调用函数

函数原型:

static bool ocl_dft_rows(InputArray _src, OutputArray _dst, int nonzero_rows, int flags, int fftType)
static bool ocl_dft_cols(InputArray _src, OutputArray _dst, int nonzero_cols, int flags, int fftType)

看其中一个的源码:

static bool ocl_dft_rows(InputArray _src, OutputArray _dst, int nonzero_rows, int flags, int fftType)
{
    int type = _src.type(), depth = CV_MAT_DEPTH(type);
    Ptr<OCL_FftPlan> plan = OCL_FftPlanCache::getInstance().getFftPlan(_src.cols(), depth);
    return plan->enqueueTransform(_src, _dst, nonzero_rows, flags, fftType, true);
}

5.fft计算的对象池

每个确定尺寸的fft计算之前,需要建立一系列的初始化数据;如果每次计算相同尺寸都建立这些初始化数据,明显很浪费。
于是建立一个对象池,每出现一个fft计算的新尺寸,就缓存一个对象。空间换时间(但是长期运行场景要注意内存消耗)。

    Ptr<OCL_FftPlan> OCL_FftPlanCache::getFftPlan(int dft_size, int depth)
    {
        int key = (dft_size << 16) | (depth & 0xFFFF);
        std::map<int, Ptr<OCL_FftPlan> >::iterator f = planStorage.find(key);
        if (f != planStorage.end())
        {
            return f->second;
        }
        else
        {
            Ptr<OCL_FftPlan> newPlan = Ptr<OCL_FftPlan>(new OCL_FftPlan(dft_size, depth));
            planStorage[key] = newPlan;
            return newPlan;
        }
    }

6. fft对象

opencv-3.4.3\modules\core\src\dxt.cpp:1881
struct OCL_FftPlan
初始化在构造函数:OCL_FftPlan(int _size, int _depth)
计算使用这个方法: bool enqueueTransform(InputArray _src, OutputArray _dst, int num_dfts, int flags, int fftType, bool rows = true) const
方法的主要代码是构造核函数的编译参数。

6.1 opencl核函数的编译、绑定参数、执行

enqueueTransform()方法的核心代码如下:

        ocl::Kernel k(kernel_name.c_str(), ocl::core::fft_oclsrc, options);
        if (k.empty())
            return false;

        k.args(ocl::KernelArg::ReadOnly(src), ocl::KernelArg::WriteOnly(dst), ocl::KernelArg::ReadOnlyNoSize(twiddles), thread_count, num_dfts);
        return k.run(2, globalsize, localsize, false);

ocl::Kernel 对象用于编译opencl的核函数。
ocl::KernelArg 用于绑定核函数的执行参数。
k.run() 执行核函数。

6.2 核函数的定义

ocl::core::fft_oclsrc 这个常量对象定义了核函数的源码,搜索了所有的.h, .hpp, .cpp都没有找到定义。
源码这部分代码是编译过程生成的。
定义在:
opencv-3.4.3/build/modules/core/opencl_kernels_core.hpp:21

extern struct cv::ocl::internal::ProgramEntry fft_oclsrc;

实现在:
opencv-3.4.3/build/modules/core/opencl_kernels_core.cpp:770

struct cv::ocl::internal::ProgramEntry fft_oclsrc={moduleName, "fft",
"#define SQRT_2 0.707106781188f\n"

看来只是用一个脚本,把opencl的核函数代码转换成为C++字符串而已。

6.3 核函数的定义文件

最终找到opencl fft的核函数的文件:
opencv-3.4.3\modules\core\src\opencl\fft.cl

这里有一个明显的问题,核函数每次调用都要编译一次。并未看见哪里缓存了编译的结果。

7.cv::dft()可能的优化点

  • 每次调用核函数都要编译,应该缓存ocl::Kernel对象
  • 把C函数的风格修改为面向对象风格,把UMat数据upload/核函数运行/UMat数据download等部分都加入异步队列。使得连续计算多个dft()的时候,可以避免CPU等待GPU的结果。

原文地址:https://www.cnblogs.com/ahfuzhang/p/11083423.html

时间: 2024-10-17 00:42:11

【源码阅读】opencv中opencl版本的dft函数的实现细节的相关文章

Android - 源码阅读 - Eclipse中在线阅读

在线阅读Android源码:grepcode 安装Eclipse GrepCode 插件:GrepCode Eclipse Plugin Help->Install New Software... Add: GrepCode http://repository.grepcode.com/java/ext-eclipse 使用 将鼠标置于需要查看的类名或者包上,然后按"F3",弹出 Class File Editor 显示对应class文件. 在Eclipse IDE 环境中在线查

Linux kernel源码阅读笔记2-2.6版本调度器sched.c功能

来自:http://www.ibm.com/developerworks/cn/linux/l-scheduler/ 2.6 版本调度器的源代码都很好地封装到了 /usr/src/linux/kernel/sched.c 文件中.我们在表 1 中对在这个文件中可以找到的一些有用的函数进行了总结. 表 1. Linux 2.6 调度器的功能 函数名 函数说明 schedule 调度器主函数.调度优先级最高的任务执行. load_balance 检查 CPU,查看是否存在不均衡的情况,如果不均衡,就

淘宝数据库OceanBase SQL编译器部分 源码阅读--解析SQL语法树

OceanBase是 阿里巴巴集团自主研发的可扩展的关系型数据库,实现了跨行跨表的事务,支持数千亿条记录.数百TB数据上的SQL操作.在阿里巴巴集团 下,OceanBase数据库支持了多个重要业务的数据存储,包括收藏夹.直通车报表.天猫评价等.截止到2013年4月份,OceanBase线上业务 的数据量已经超过一千亿条. 看起来挺厉害的,今天我们来研究下它的源代码.关于OceanBase的架构描述有很多文档,这篇笔记也不打算涉及这些东西,只讨论OceanBase的SQL编译部分的代码. Ocea

源码阅读的方法

小弟我入行不久,实打实的菜鸟,最近由于个人兴趣和工作需要,读了一些源码,感觉还不错,谨以此文做个小小的总结以达到抛砖引玉之效,如有错误和不足的地方希望各位补充. 感谢开源,让我这种并没有受过系统的软件开发训练的工程师也能学习到业界一流的代码,并通过源代码和一些顶尖的程序员零距离的对话.源码对于我这种经验算不上丰富的小白来说是恐怖的,但真正开始的时候却也是魅力无限的,当全身心地沉浸在代码中时,专注和兴奋度远大于听一次讲座或者看一本书,但如果方法不对则很有可能刚刚形成的勇气和兴趣会被无情地摧毁. 我

muduo2.0源码阅读记录

花了20天的时间读了陈硕先生的<Linux多线程服务端编程>一书的前8章.当然,每天阅读的时间并不算多,中间有些部分也反反复复看了几遍,最后也算是能勉强接受作者传授的知识.配合书把muduo2.0网络部分的代码和日志库代码细读了一遍,这也算是个人第一次较为深入地去读取一个开源项目源码.通过书和源码的阅读,确实是对不少东西加深了理解. 本来想按自己的理解来写源码阅读笔记的,但考虑到网上关于muduo代码的解析文章已经很多并且写的很好了,就放弃了这个想法.摘录几个自己在源码阅读过程中参考的网页:

源码阅读笔记 - 1 MSVC2015中的std::sort

大约寒假开始的时候我就已经把std::sort的源码阅读完毕并理解其中的做法了,到了寒假结尾,姑且把它写出来 这是我的第一篇源码阅读笔记,以后会发更多的,包括算法和库实现,源码会按照我自己的代码风格格式化,去掉或者展开用于条件编译或者debug检查的宏,依重要程度重新排序函数,但是不会改变命名方式(虽然MSVC的STL命名实在是我不能接受的那种),对于代码块的解释会在代码块前(上面)用注释标明. template<class _RanIt, class _Diff, class _Pr> in

CI框架源码阅读笔记3 全局函数Common.php

从本篇开始,将深入CI框架的内部,一步步去探索这个框架的实现.结构和设计. Common.php文件定义了一系列的全局函数(一般来说,全局函数具有最高的加载优先权,因此大多数的框架中BootStrap引导文件都会最先引入全局函数,以便于之后的处理工作). 打开Common.php中,第一行代码就非常诡异: if ( ! defined('BASEPATH')) exit('No direct script access allowed'); 上一篇(CI框架源码阅读笔记2 一切的入口 index

如何阅读Java源码 阅读java的真实体会

刚才在论坛不经意间,看到有关源码阅读的帖子.回想自己前几年,阅读源码那种兴奋和成就感(1),不禁又有一种激动. 源码阅读,我觉得最核心有三点:技术基础+强烈的求知欲+耐心. 说到技术基础,我打个比方吧,如果你从来没有学过Java,或是任何一门编程语言如C++,一开始去啃<Core Java>,你是很难从中吸收到营养的,特别是<深入Java虚拟机>这类书,别人觉得好,未必适合现在的你. 虽然Tomcat的源码很漂亮,但我绝不建议你一开始就读它.我文中会专门谈到这个,暂时不展开. 强烈

淘宝数据库OceanBase SQL编译器部分 源码阅读--生成物理查询计划

SQL编译解析三部曲分为:构建语法树,制定逻辑计划,生成物理执行计划.前两个步骤请参见我的博客<<淘宝数据库OceanBase SQL编译器部分 源码阅读--解析SQL语法树>>和<<淘宝数据库OceanBase SQL编译器部分 源码阅读--生成逻辑计划>>.这篇博客主要研究第三步,生成物理查询计划. 一. 什么是物理查询计划 与之前的阅读方法一致,这篇博客的两个主要问题是what 和how.那么什么是物理查询计划?物理查询计划能够直接执行并返回数据结果数