博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
tensorflow中的共享变量(sharing variables) 最佳方式variable_scope()命名空间来完成
阅读量:4262 次
发布时间:2019-05-26

本文共 3581 字,大约阅读时间需要 11 分钟。

当训练复杂模型时,可能经常需要共享大量的变量。例如,使用测试集来测试已训练好的模型性能表现时,需要共享已训练好模型的变量,如全连接层的权值。

而且我们还会遇到以下问题:

比如,我们创建了一个简单的图像滤波器模型。如果只使用tf.Variable,那么我们的模型可能如下

复制代码

def my_image_filter(input_images):    conv1_weights = tf.Variable(tf.random_normal([5, 5, 32, 32]),        name="conv1_weights")    conv1_biases = tf.Variable(tf.zeros([32]), name="conv1_biases")    conv1 = tf.nn.conv2d(input_images, conv1_weights,        strides=[1, 1, 1, 1], padding='SAME')    relu1 = tf.nn.relu(conv1 + conv1_biases)    conv2_weights = tf.Variable(tf.random_normal([5, 5, 32, 32]),        name="conv2_weights")    conv2_biases = tf.Variable(tf.zeros([32]), name="conv2_biases")    conv2 = tf.nn.conv2d(relu1, conv2_weights,        strides=[1, 1, 1, 1], padding='SAME')    return tf.nn.relu(conv2 + conv2_biases)

复制代码

这个模型中有4个不同的变量:conv1_weightsconv1_biasesconv2_weights, and conv2_biases

当我们想再次使用这个模型的时候出现问题了:在两个不同的图片image1和image2上应用以上模型,当然,我们想这两张图片被相同参数的同一个滤波器处理。如果我们两次调用my_image_filter()的话,则会创建两个不同的变量集,每个变量集中各4个变量。

# First call creates one set of 4 variables.result1 = my_image_filter(image1)# Another set of 4 variables is created in the second call.result2 = my_image_filter(image2)

共享变量的一种常见方法是在单独的代码段中创建它们,并将它们传递给使用它们的函数。 例如通过使用字典:

复制代码

variables_dict = {    "conv1_weights": tf.Variable(tf.random_normal([5, 5, 32, 32]),        name="conv1_weights")    "conv1_biases": tf.Variable(tf.zeros([32]), name="conv1_biases")    ... etc. ...}def my_image_filter(input_images, variables_dict):    conv1 = tf.nn.conv2d(input_images, variables_dict["conv1_weights"],        strides=[1, 1, 1, 1], padding='SAME')    relu1 = tf.nn.relu(conv1 + variables_dict["conv1_biases"])    conv2 = tf.nn.conv2d(relu1, variables_dict["conv2_weights"],        strides=[1, 1, 1, 1], padding='SAME')    return tf.nn.relu(conv2 + variables_dict["conv2_biases"])# Both calls to my_image_filter() now use the same variablesresult1 = my_image_filter(image1, variables_dict)result2 = my_image_filter(image2, variables_dict)

复制代码

但是像上面这样在代码外面创建变量很方便, 破坏了封装

  • 构建图形的代码必须记录要创建的变量的名称,类型和形状。
  • 代码更改时,调用者可能必须创建更多或更少或不同的变量。

解决问题的一种方法是使用类创建一个模型,其中类负责管理所需的变量。 一个较简便的解决方案是,使用TensorFlow提供variable scope机制,通过这个机制,可以让我们在构建模型时轻松共享命名变量。

 

如何实现共享变量

tensorflow中的变量共享是通过 tf.variab_scope() 和 tf.get_variable() 来实现的

  • tf.variable_scope(<scope_name>):  管理传递给tf.get_variable()的names的命名空间
  • tf.get_variable(<name>, <shape>, <initializer>): 创建或返回一个给定名字的变量

为了看下tf.get_variable()如何解决以上问题,我们在一个单独的函数里重构创建一个卷积的代码,并命名为conv_relu:

复制代码

def conv_relu(input, kernel_shape, bias_shape):    # Create variable named "weights".    weights = tf.get_variable("weights", kernel_shape,        initializer=tf.random_normal_initializer())    # Create variable named "biases".    biases = tf.get_variable("biases", bias_shape,        initializer=tf.constant_initializer(0.0))    conv = tf.nn.conv2d(input, weights,        strides=[1, 1, 1, 1], padding='SAME')    return tf.nn.relu(conv + biases)

复制代码

此这个函数使用“weights”和“biases”命名变量。 我们希望将它用于conv1和conv2,但变量需要具有不同的名称。 这就是tf.variable_scope()发挥作用的地方:它为各变量分配命名空间。

复制代码

def my_image_filter(input_images):    with tf.variable_scope("conv1"):        # Variables created here will be named "conv1/weights", "conv1/biases".        relu1 = conv_relu(input_images, [5, 5, 32, 32], [32])    with tf.variable_scope("conv2"):        # Variables created here will be named "conv2/weights", "conv2/biases".        return conv_relu(relu1, [5, 5, 32, 32], [32])

复制代码

现在,我们来看下两次调用my_image_filter()会怎样:

result1 = my_image_filter(image1)result2 = my_image_filter(image2)# Raises ValueError(... conv1/weights already exists ...)

如你所见,tf.get_variable() 会检查已存在的变量是不是意外地共享。如果你想共享它们,你需要通过如下设置reuse_variables()来指定它。

转载地址:http://kqlei.baihongyu.com/

你可能感兴趣的文章
设计模式之:单例模式
查看>>
leetcode之数值计算类-----9. Palindrome Number(判断一个数是否为回文数)
查看>>
leetcode之链表类之链表排序-----147/148. 链表快速排序 链表插入排序
查看>>
leetcode之链表类之链表归并类-----OJ 2/21/23/445 链表相加求和 链表归并
查看>>
leetcode之链表逆序翻转类-----92/206 逆序 24/25/61/143 按规则翻转 86/234 双指针分治 19/82/83/203 按规则删除
查看>>
leetcode之深搜递归回溯类-----1/167/653. two sum(记忆化搜索寻找和为给定值的两个数)
查看>>
leetcode之深搜递归回溯类之排列与组合类-----77/39/40/216/317 组合 78/90/368 子排列 22/79/93/131 典型递归回溯 46/47 全排列
查看>>
leetcode之二叉树类之最小公共祖先-----236/235. Lowest Common Ancestor of a Binary/Binary Search Tree
查看>>
leetcode之二叉树类之路径和系列-----112/113/124/257/437 path sum(牵扯附加OJ572和OJ100, 子树和子拓扑)
查看>>
leetcode之二叉树类之二叉树深度系列-----104/111/110/108/109 二叉树最大/最小深度/AVL树的判断和由有序序列生成(牵扯分治相关,OJ105/106)
查看>>
leetcode之二叉树类之二叉树遍历系列-----94/144/145/102/107/103
查看>>
leetcode之二叉树类之二叉树中序遍历运用-----OJ173/230/98/99/285 二叉树迭代器/BST第K小元素/判断BST是否合法/恢复BST/二叉树下个节点
查看>>
leetcode之链表类之相交成环类-----OJ 160/141/142 链表相交 链表环
查看>>
leetcode之数组类之区间类-----OJ 56/57/435/239 重叠区间个数 合并区间 插入区间 滑动窗口最大值
查看>>
leetcode之数组类之数组的旋转与分治类-----OJ 189/33/81/153/154 数组旋转 旋转数组搜索 88 有序数组合并 4 两个有序数组寻找第K个元素/中位数 35 寻找插入位置
查看>>
leetcode之双指针类-----OJ 228/15/16/18/26/80/121/75
查看>>
关于典型的存储引擎及其代表(mysql、redis/memcached、leveldb/rocksdb/hbase系)
查看>>
记 今日头条广告架构社招面试
查看>>
数据结构算法面试总结 序
查看>>
auto关键字实现简易的数值范围迭代器
查看>>