tf.contrib.slim.arg_scope 完整

缘由

  最近一直在看深度学习的代码,又一次看到了slim.arg_scope()的嵌套使用,具体代码如下:

with slim.arg_scope(
      [slim.conv2d, slim.separable_conv2d],
      weights_initializer=tf.truncated_normal_initializer(
          stddev=weights_initializer_stddev),
      activation_fn=activation_fn,
      normalizer_fn=slim.batch_norm if use_batch_norm else None):
    with slim.arg_scope([slim.batch_norm], **batch_norm_params):
      with slim.arg_scope(
          [slim.conv2d],
          weights_regularizer=slim.l2_regularizer(weight_decay)):
        with slim.arg_scope(
            [slim.separable_conv2d],
            weights_regularizer=depthwise_regularizer) as arg_sc:
          return arg_sc

  由上述代码可以看到,第一层argscope有slim.conv2d参数,第三层也有这个参数,那么不同层的参数是如何相互补充,作用到之后的代码块中,就是这篇博文的出发点。

准备工作

  我们先看一下arg_scope的函数声明:

@tf_contextlib.contextmanager
def arg_scope(list_ops_or_scope, **kwargs):

  有函数修饰符@tf_contextlib.contextmanager修饰arg_scope函数,我们先研究下这个函数修饰符。

@的作用

  @之后一般接一个可调用对象(tf_contextlib.contextmanager),一起构成函数修饰符(装饰器),这个可调用对象将被修饰函数(arg_scope)作为参数,执行一系列辅助操作,我们来看一个demo

import time

def my_time(func):
    print(time.ctime())
    return func()

@my_time  # 从这里可以看出@time 等价于 time(xxx()),但是这种写法你得考虑python代码的执行顺序
def xxx():
    print(‘Hello world!‘)

运行结果:
Wed Jul 26 23:01:21 2017
Hello world!

  在这个例子中,xxx函数实现我们的主要功能,打印Hello world!,但我们想给xxx函数添加一些辅助操作,于是我们用函数修饰符@my_time,使xxx函数先打印时间。整个例子的执行流程为调用my_time可调用对象,它接受xxx函数作为参数,先打印时间,再执行xxx函数。

上下文管理器

  既然arg_scope函数存在装饰器,那么我们应该了解一下装饰器提供了什么辅助功能,代码为:

import contextlib as _contextlib

from tensorflow.python.util import tf_decorator

def contextmanager(target):
  """A tf_decorator-aware wrapper for `contextlib.contextmanager`.
  Usage is identical to `contextlib.contextmanager`.
  Args:
    target: A callable to be wrapped in a contextmanager.
  Returns:
    A callable that can be used inside of a `with` statement.
  """
  context_manager = _contextlib.contextmanager(target)
  return tf_decorator.make_decorator(target, context_manager, ‘contextmanager‘)

  可以看到导入了contextlib库,这个库提供了contextmanager函数,这也是一个装饰器,它使被修饰的函数具有上下文管理器的功能。上下文管理器的功能是在我们执行一段代码块之前做一些准备工作,执行完代码块之后做一些收尾工作,同样先来看一个上下文管理器的例子:

import time

class MyTimer(object):
    def __init__(self, verbose = False):
        self.verbose = verbose

    def __enter__(self):
        self.start = time.time()
        return self

    def __exit__(self, *unused):
        self.end = time.time()
        self.secs = self.end - self.start
        self.msecs = self.secs * 1000
        if self.verbose:
            print "elapsed time: %f ms" %self.msecs

with MyTimer(True):  print(‘Hello world!‘)

  类MyTimer中的__enter__和__exit__方法分别是准备工作和收尾工作。整个代码的执行过程为:先执行__enter__方法,__enter__方法中的返回值(这个例子中是self)可以用到代码块中,再执行语句块,这个例子中是print函数,最后执行__exit__方法,更多关于上下文管理器的内容可以看,我的例子也是从那copy的。contextlib中实现上下文管理器稍有不同,一样来看个例子:

from contextlib import contextmanager

@contextmanager
def tag(name):
    print "<%s>" % name
    yield
    print "</%s>" % name

>>> with tag("h1"):
...    print "foo"
运行结果:
<h1>
foo
</h1>

  tag函数中yield之前的代码相当于__enter__方法,yield产生的生成器相当于__enter__方法的返回值,yield之后的代码相当于__exit__方法。

arg_scope方法

  这里我把arg_scope方法中代码稍微做了一些精简,代码如下:

arg_scope = [{}]

@tf_contextlib.contextmanager
def arg_scope(list_ops_or_scope, **kwargs):try:
      current_scope = current_arg_scope().copy()
      for op in list_ops_or_scope:
        key = arg_scope_func_key(op) # op的代号
        if not has_arg_scope(op): # op是否用@slim.add_arg_scope修饰,这会在下一篇中介绍
          raise ValueError(‘%s is not decorated with @add_arg_scope‘,
                           _name_op(op))
        if key in current_scope:
          current_kwargs = current_scope[key].copy()
          current_kwargs.update(kwargs)
          current_scope[key] = current_kwargs
        else:
          current_scope[key] = kwargs.copy()
      _get_arg_stack().append(current_scope)
      yield current_scope
    finally:
      _get_arg_stack().pop()

# demo
with slim.arg_scope(
      [slim.conv2d, slim.separable_conv2d],
      weights_initializer=tf.truncated_normal_initializer(
          stddev=weights_initializer_stddev),
      activation_fn=activation_fn,
      normalizer_fn=slim.batch_norm if use_batch_norm else None):
    with slim.arg_scope([slim.batch_norm], **batch_norm_params):
      with slim.arg_scope(
          [slim.conv2d],
          weights_regularizer=slim.l2_regularizer(weight_decay)):
        with slim.arg_scope(
            [slim.separable_conv2d],
            weights_regularizer=depthwise_regularizer) as arg_sc:
          return arg_sc

  我们沿着demo一步步看,其中arg_scope是一个栈。先看第一层,current_arg_scope()函数返回栈中最后一个元素,此时是空字典{},由于字典为空,所以会把conv2d和separable_conv2d加入字典,此时栈为[{‘conv2d‘: kargs, ‘separable_conv2d‘: kargs}],然后执行接下来的代码块,即第二层with,finally中函数要在代码块执行完后再执行;第二层执行完后栈为[{‘conv2d‘: kargs, ‘separable_conv2d‘: kargs},{‘conv2d‘: kargs, ‘separable_conv2d‘: kargs, ‘batch_norm‘: batch_norm_params}],可以看到是将第一层的字典复制之后检查其中是否有与第二层相同的op,相同的op就把参数更新,不同的op就增加键值对,如这里的batch_norm。

  回到我们开头提到的问题,不同层的参数是如何互相补充的?现在我们可以看到,参数存储在栈中,每叠加一层,就在原有参数基础上把新参数添加上去。

                                            最后编辑于20:54:35 2018-07-23

原文地址:https://www.cnblogs.com/zzy-tf/p/9356883.html

时间: 2024-11-09 11:22:06

tf.contrib.slim.arg_scope 完整的相关文章

tf.contrib.slim的介绍

本文主要参考博客:博客连接 前言基础: 验证本地的tf.contrib.slim模块是否有效: 1 python -c "import tensorflow.contrib.slim as slim;eval=slim.evaluation.evaluate_once" 下载models模块: 下载连接.下载后解压到你设定的文件夹,笔者解压到"E:\TENSORFLOW\models" 找到并且打开文件夹"E:\TENSORFLOW\models\rese

tf.contrib.slim.data数据加载 综述

TF-Slim为了方便加载各种数据类型(如TFRocords或者文本文件)的数据,创建了这个库. Dataset 这里的数据库与通常意义下数据库是不同的,这里数据库是python一个类,它负责将原始数据通过流水线加工成为我们需要的数据格式. TF-Slim defines a dataset to be a set of files (that may or may not be encoded) representing a finite set of samples, and which c

图融合之加载子图:Tensorflow.contrib.slim与tf.train.Saver之坑

import tensorflow as tf import tensorflow.contrib.slim as slim import rawpy import numpy as np import tensorflow as tf import struct import glob import os from PIL import Image import time __sony__ = 0 __huawei__ = 1 __blackberry__ = 2 __stage_raw2ra

tf中slim的使用

https://blog.csdn.net/Cyiano/article/details/75006883 https://blog.csdn.net/transMaple/article/details/78273560 slim库是tensorflow中的一个高层封装,它将原来很多tf中复杂的函数进一步封装,省去了很多重复的参数,以及平时不会考虑到的参数.可以理解为tensorflow的升级版. 导入方式: import tensorflow as tf import tensorflow.

学习笔记TF044:TF.Contrib组件、统计分布、Layer、性能分析器tfprof

TF.Contrib,开源社区贡献,新功能,内外部测试,根据反馈意见改进性能,改善API友好度,API稳定后,移到TensorFlow核心模块.生产代码,以最新官方教程和API指南参考. 统计分布.TF.contrib.ditributions模块,Bernoulli.Beta.Binomial.Gamma.Ecponential.Normal.Poisson.Uniform等统计分布,统计研究.应用中常用,各种统计.机器学习模型基石,概率模型.图形模型依赖. 每个不同统计分布不同特征.函数,同

tf.contrib.learn.preprocessing.VocabularyProcessor()

tf.contrib.learn.preprocessing.VocabularyProcessor (max_document_length, min_frequency=0, vocabulary=None, tokenizer_fn=None) 参数: max_document_length: 文档的最大长度.如果文本的长度大于最大长度,那么它会被剪切,反之则用0填充. min_frequency: 词频的最小值,出现次数小于最小词频则不会被收录到词表中. vocabulary: Cate

tf.contrib.rnn.static_rnn与tf.nn.dynamic_rnn区别

tf.contrib.rnn.static_rnn与tf.nn.dynamic_rnn区别 https://blog.csdn.net/u014365862/article/details/78238807 MachineLP的Github(欢迎follow):https://github.com/MachineLP 我的GitHub:https://github.com/MachineLP/train_cnn-rnn-attention 自己搭建的一个框架,包含模型有:vgg(vgg16,vg

关于tensorflow里面的tf.contrib.rnn.BasicLSTMCell 中num_units参数问题

这里的num_units参数并不是指这一层油多少个相互独立的时序lstm,而是lstm单元内部的几个门的参数,这几个门其实内部是一个神经网络,答案来自知乎: class TRNNConfig(object): """RNN配置参数""" # 模型参数 embedding_dim = 100 # 词向量维度 seq_length = 100 # 序列长度 num_classes = 2 # 类别数 vocab_size = 10000 # 词汇表达

深度学习原理与框架-递归神经网络-RNN网络基本框架(代码?) 1.rnn.LSTMCell(生成单层LSTM) 2.rnn.DropoutWrapper(对rnn进行dropout操作) 3.tf.contrib.rnn.MultiRNNCell(堆叠多层LSTM) 4.mlstm_cell.zero_state(state初始化) 5.mlstm_cell(进行LSTM求解)

问题:LSTM的输出值output和state是否是一样的 1. rnn.LSTMCell(num_hidden, reuse=tf.get_variable_scope().reuse)  # 构建单层的LSTM网络 参数说明:num_hidden表示隐藏层的个数,reuse表示LSTM的参数进行复用 2.rnn.DropoutWrapper(cell, output_keep_prob=keep_prob) # 表示对rnn的输出层进行dropout 参数说明:cell表示单层的lstm,o