首页 > 编程 > Python > 正文

TensorFlow中权重的随机初始化的方法

2020-01-04 15:52:27
字体:
来源:转载
供稿:网友

一开始没看懂stddev是什么参数,找了一下,在tensorflow/python/ops里有random_ops,其中是这么写的:

def random_normal(shape, mean=0.0, stddev=1.0, dtype=types.float32,         seed=None, name=None): """Outputs random values from a normal distribution. Args:  shape: A 1-D integer Tensor or Python array. The shape of the output tensor.  mean: A 0-D Tensor or Python value of type `dtype`. The mean of the normal   distribution.  stddev: A 0-D Tensor or Python value of type `dtype`. The standard deviation   of the normal distribution.  dtype: The type of the output.  seed: A Python integer. Used to create a random seed for the distribution.   See   [`set_random_seed`](../../api_docs/python/constant_op.md#set_random_seed)   for behavior.  name: A name for the operation (optional). Returns:  A tensor of the specified shape filled with random normal values. """

也就是按照正态分布初始化权重,mean是正态分布的平均值,stddev是正态分布的标准差(standard deviation),seed是作为分布的random seed(随机种子,我百度了一下,跟什么伪随机数发生器还有关,就是产生随机数的),在mnist/concolutional中seed赋值为66478,挺有意思,不知道是什么原理。

后面还有truncated_normal的定义:

def truncated_normal(shape, mean=0.0, stddev=1.0, dtype=types.float32,           seed=None, name=None): """Outputs random values from a truncated normal distribution. The generated values follow a normal distribution with specified mean and standard deviation, except that values whose magnitude is more than 2 standard deviations from the mean are dropped and re-picked. Args:  shape: A 1-D integer Tensor or Python array. The shape of the output tensor.  mean: A 0-D Tensor or Python value of type `dtype`. The mean of the   truncated normal distribution.  stddev: A 0-D Tensor or Python value of type `dtype`. The standard deviation   of the truncated normal distribution.  dtype: The type of the output.  seed: A Python integer. Used to create a random seed for the distribution.   See   [`set_random_seed`](../../api_docs/python/constant_op.md#set_random_seed)   for behavior.  name: A name for the operation (optional). Returns:  A tensor of the specified shape filled with random truncated normal values. """

截断正态分布,以前都没听说过。

TensorFlow还提供了平均分布等。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持VEVB武林网。


注:相关教程知识阅读请移步到python教程频道。
发表评论 共有条评论
用户名: 密码:
验证码: 匿名发表