条件随机场之CRF++源码详解-预测

  这篇文章主要讲解CRF++实现预测的过程,预测的算法以及代码实现相对来说比较简单,所以这篇文章理解起来也会比上一篇条件随机场训练的内容要容易。

预测

  上一篇条件随机场训练的源码详解中,有一个地方并没有介绍。 就是训练结束后,会把待优化权重alpha等变量保存到文件中,也就是输出到指定的模型文件。在执行预测的时候会从模型文件读出相关的变量,这个过程其实就是数据序列化与反序列化,该过程跟条件随机场算法关系不大,因此为了突出重点源码解析里就没有介绍这部分,有兴趣的朋友可以自己研究一下。

  CRF++预测的入口代码在crf_test.cpp的main函数中,最终会调用tragger.cpp的int crfpp_test(const Param &param)函数,期间会做一些输入参数的处理、异常处理、读取模型文件等操作。一切准备就绪就会打开待预测的文件,进行预测。正式探讨预测代码之前,我们先看下预测的理论基础。条件随机场的预测用到了维特比算法,公式如下:

\begin{aligned} y^* &= \arg \max_yP_w(y|x) \\ &=  \arg \max_y\frac{ \exp \left \{ \sum_{k=1}^Kw_kf_k(y,x) \right\}}{Z_w(x)} \\ &=  \arg \max_y \exp \left \{\sum_{k=1}^Kw_kf_k(y,x) \right\} \\ &= \arg \max_y \ \sum_{k=1}^Kw_kf_k(y,x) \end{aligned}

从公式我们可以看出,我们求的概率最大值就是要求代价最大。接下来就看下CRF++的源码,代码在tragger.cpp的crfpp_test函数中:

while (*is) {//is是打开的测试文件,可以输入多个测试文件做预测
      tagger.parse_stream(is.get(), os.get());
}

bool TaggerImpl::parse_stream(std::istream *is,
                              std::ostream *os) {
  if (!read(is) || !parse()) {//read函数在特征篇讲过,不再赘述,调用parse函数进行预测
    return false;
  }
  if (x_.empty()) {
    return true;
  }
  toString(); //格式化输出,-v 会输出每个词预测为某个label的概率,-n会输出预测序列概率最大的前n个,如果理解上一篇训练过程,再看这个函数就比较容易理解,无非就是概率计算,这里不再赘述
  os->write(os_.data(), os_.size()); //输出到输出文件
  return true;
}
bool TaggerImpl::parse() {
  CHECK_FALSE(feature_index_->buildFeatures(this)) //构建特征,同特征篇代码,不再赘述
      << feature_index_->what();

  if (x_.empty()) {
    return true;
  }
  buildLattice(); //构建无向图,因为要计算代价最大的序列,训练篇讲过,不再赘述
  if (nbest_ || vlevel_ >= 1) {
    forwardbackward(); //前向后向算法,为了计算单词节点的概率,训练篇讲过,不再赘述
  }
  viterbi();  //维特比算法, 做预测的代码
  if (nbest_) {
    initNbest();
  }

  return true;
}
void TaggerImpl::viterbi() {
  for (size_t i = 0;   i < x_.size(); ++i) { //遍历每个词
    for (size_t j = 0; j < ysize_; ++j) { //遍历每个词的每个label
      double bestc = -1e37;
      Node *best = 0;
      const std::vector<Path *> &lpath = node_[i][j]->lpath;
      for (const_Path_iterator it = lpath.begin(); it != lpath.end(); ++it) { //从前一个词到当前词的代价之和 = max(前一个节点的代价 + 前一个节点的边代价 + 当前节点代价)
        double cost = (*it)->lnode->bestCost +(*it)->cost +
            node_[i][j]->cost;
        if (cost > bestc) { //记录截止当前节点最大的代价, 以及对应的前一个节点
          bestc = cost;
          best  = (*it)->lnode;
        }
      }
      node_[i][j]->prev     = best; //记录前一个几点
      node_[i][j]->bestCost = best ? bestc : node_[i][j]->cost; //记录最大的代价值, 如果best = 0代表第一个词,没有左边,最大代价就是节点的代价node_[i][j]->cost
    }
  }

  double bestc = -1e37;
  Node *best = 0;
  size_t s = x_.size()-1;
  for (size_t j = 0; j < ysize_; ++j) { //遍历最后一个词的节点,截止到最后一个词的代价最大值就是整个句子的最大代价
    if (bestc < node_[s][j]->bestCost) {
      best  = node_[s][j];
      bestc = node_[s][j]->bestCost;
    }
  }

  for (Node *n = best; n; n = n->prev) {//记录代价最大的预测序列
    result_[n->x] = n->y;
  }

  cost_ = -node_[x_.size()-1][result_[x_.size()-1]]->bestCost;
}

预测的核心代码就看完了,大部分复用了训练过程的逻辑。可以看到预测的过程跟公式是一致的,无非就是求能够让代价最大的label序列(标记序列),这就是维特比算法。

总结

  至此,我们的条件随机场之CRF++源码详解系列就结束了,主要涵盖了特征处理、训练以及预测三个核心过程。结合CRF++源码我们可以更形象的、更通俗的去理解条件随机场模型。以后想起条件随机场模型,我们脑海浮现的不再是一堆公式,而是一个无向图,在图上进行代价计算、前向后向计算、期望值的计算以及梯度的计算等一系列的过程。希望这个系列对于正在学习条件随机场的朋友能有帮助,如果本文阐述的有歧义、不通俗、不容易理解的地方,欢迎留言区交流,我将及时更正、回复,希望我们一起提高。

原文地址:https://www.cnblogs.com/duma/p/10344232.html

时间: 2024-11-02 13:05:19

条件随机场之CRF++源码详解-预测的相关文章

条件随机场之CRF++源码详解-开篇

介绍 最近在用条件随机场做切分标注相关的工作,系统学习了下条件随机场模型.能够理解推导过程,但还是比较抽象.因此想研究下模型实现的具体过程,比如:1) 状态特征和转移特征具体是什么以及如何构造 2)前向后向算法具体怎么实现 等等.那么,想要深入了解一个算法比较好的方式就是阅读现有的开源项目.阅读好的开源项目不但可以深入理解原理,还可以学习一些工程实践的经验.我阅读条件随机场的开源项目是CRF++.我在阅读CRF++源码的时候走过一些弯路也积累了一些经验,想把这个过程和经验总结下来,希望能够对正在

条件随机场之CRF++源码详解-训练

上篇的CRF++源码阅读中, 我们看到CRF++如何处理样本以及如何构造特征.本篇文章将继续探讨CRF++的源码,并且本篇文章将是整个系列的重点,会介绍条件随机场中如何构造无向图.前向后向算法.如何计算条件概率.如何计算特征函数的期望以及如何求似然函数的梯度.本篇将结合条件随机场公式推导和CRF++源码实现来讲解以上问题.原文链接 开启多线程 我们接着上一篇encoder.cpp文件中的learn函数继续看,该函数的下半部分将会调用具体的学习算法做训练.目前CRF++支持两种训练算法,一种是拟牛

Java concurrent AQS 源码详解

一.引言 AQS(同步阻塞队列)是concurrent包下锁机制实现的基础,相信大家在读完本篇博客后会对AQS框架有一个较为清晰的认识 这篇博客主要针对AbstractQueuedSynchronizer的源码进行分析,大致分为三个部分: 静态内部类Node的解析 重要常量以及字段的解析 重要方法的源码详解. 所有的分析仅基于个人的理解,若有不正之处,请谅解和批评指正,不胜感激!!! 二.Node解析 AQS在内部维护了一个同步阻塞队列,下面简称sync queue,该队列的元素即静态内部类No

深入Java基础(四)--哈希表(1)HashMap应用及源码详解

继续深入Java基础系列.今天是研究下哈希表,毕竟我们很多应用层的查找存储框架都是哈希作为它的根数据结构进行封装的嘛. 本系列: (1)深入Java基础(一)--基本数据类型及其包装类 (2)深入Java基础(二)--字符串家族 (3)深入Java基础(三)–集合(1)集合父类以及父接口源码及理解 (4)深入Java基础(三)–集合(2)ArrayList和其继承树源码解析以及其注意事项 文章结构:(1)哈希概述及HashMap应用:(2)HashMap源码分析:(3)再次总结关键点 一.哈希概

Android View 事件分发机制源码详解(View篇)

前言 在Android View 事件分发机制源码详解(ViewGroup篇)一文中,主要对ViewGroup#dispatchTouchEvent的源码做了相应的解析,其中说到在ViewGroup把事件传递给子View的时候,会调用子View的dispatchTouchEvent,这时分两种情况,如果子View也是一个ViewGroup那么再执行同样的流程继续把事件分发下去,即调用ViewGroup#dispatchTouchEvent:如果子View只是单纯的一个View,那么调用的是Vie

Android编程之Fragment动画加载方法源码详解

上次谈到了Fragment动画加载的异常问题,今天再聊聊它的动画加载loadAnimation的实现源代码: Animation loadAnimation(Fragment fragment, int transit, boolean enter, int transitionStyle) { 接下来具体看一下里面的源码部分,我将一部分一部分的讲解,首先是: Animation animObj = fragment.onCreateAnimation(transit, enter, fragm

Spring IOC源码详解之容器依赖注入

Spring IOC源码详解之容器依赖注入 上一篇博客中介绍了IOC容器的初始化,通过源码分析大致了解了IOC容器初始化的一些知识,先简单回顾下上篇的内容 载入bean定义文件的过程,这个过程是通过BeanDefinitionReader来完成的,其中通过 loadBeanDefinition()来对定义文件进行解析和根据Spring定义的bean规则进行处理 - 事实上和Spring定义的bean规则相关的处理是在BeanDefinitionParserDelegate中完成的,完成这个处理需

Spring IOC源码详解之容器初始化

Spring IOC源码详解之容器初始化 上篇介绍了Spring IOC的大致体系类图,先来看一段简短的代码,使用IOC比较典型的代码 ClassPathResource res = new ClassPathResource("beans.xml"); DefaultListableBeanFactory factory = new DefaultListableBeanFactory(); XmlBeanDefinitionReader reader = new XmlBeanDe

IntentService源码详解

IntentService可以做什么: 如果你有一个任务,分成n个子任务,需要它们按照顺序完成.如果需要放到一个服务中完成,那么IntentService就会使最好的选择. IntentService是什么: IntentService是一个Service(看起来像废话,但是我第一眼看到这个名字,首先注意的是Intent啊.),所以如果自定义一个IntentService的话,一定要在AndroidManifest.xml里面声明. 从上面的"可以做什么"我们大概可以猜测一下Inten