阅读(2907) (11)

tf.one_hot函数:返回one-hot张量

2017-11-07 17:39:54 更新
tf.one_hot 函数
one_hot(
    indices,
    depth,
    on_value=None,
    off_value=None,
    axis=None,
    dtype=None,
    name=None
)

定义在:tensorflow/python/ops/array_ops.py.

参见指南:张量变换>分割和连接

返回一个 one-hot 张量.

索引中由索引表示的位置取值 on_value,而所有其他位置都取值 off_value.

on_value 和 off_value必须具有匹配的数据类型.如果还提供了 dtype,则它们必须与 dtype 指定的数据类型相同.

如果未提供 on_value,则默认值将为 1,其类型为 dtype.

如果未提供 off_value,则默认值为 0,其类型为 dtype.

如果输入的索引的秩为 N,则输出的秩为 N+1.新的坐标轴在维度上创建 axis(默认值:新坐标轴在末尾追加).

如果索引是标量,则输出形状将是长度 depth 的向量.

如果索引是长度 features 的向量,则输出形状将为:

features x depth if axis == -1
depth x features if axis == 0

如果索引是具有形状 [batch, features] 的矩阵(批次),则输出形状将是: 

batch x features x depth if axis == -1
batch x depth x features if axis == 1
depth x batch x features if axis == 0

如果 dtype 没有提供,则它会尝试假定 on_value 或者 off_value 的数据类型,如果其中一个或两个都传入.如果没有提供 on_value、off_value 或 dtype,则dtype 将默认为值 tf.float32.

注意:如果一个非数值数据类型输出期望(tf.string,tf.bool等),都on_value与off_value 必须被提供给one_hot.

示例

示例-1

假设如下:

indices = [0, 2, -1, 1]
depth = 3
on_value = 5.0
off_value = 0.0
axis = -1

那么输出为 [4 x 3]:

output =
[5.0 0.0 0.0]  // one_hot(0)
[0.0 0.0 5.0]  // one_hot(2)
[0.0 0.0 0.0]  // one_hot(-1)
[0.0 5.0 0.0]  // one_hot(1)

示例-2

假设如下:

indices = [[0, 2], [1, -1]]
depth = 3
on_value = 1.0
off_value = 0.0
axis = -1

那么输出是 [2 x 2 x 3]:

output =
[
  [1.0, 0.0, 0.0]  // one_hot(0)
  [0.0, 0.0, 1.0]  // one_hot(2)
][
  [0.0, 1.0, 0.0]  // one_hot(1)
  [0.0, 0.0, 0.0]  // one_hot(-1)
]

使用 on_value 和 off_value 的默认值:

indices = [0, 1, 2]
depth = 3

输出将是:

output =
[[1., 0., 0.],
 [0., 1., 0.],
 [0., 0., 1.]]

参数:

  • indices:指数的张量.
  • depth:一个标量,用于定义一个 one hot 维度的深度.
  • on_value:定义在 indices[j] = i 时填充输出的值的标量.(默认:1)
  • off_value:定义在 indices[j] != i 时填充输出的值的标量.(默认:0)
  • axis:要填充的轴(默认:-1,一个新的最内层轴).
  • dtype:输出张量的数据类型.

返回值:

  • output: one-hot 张量.

可能引发的异常:

  • TypeError:如果 on_value 或者 off_value 的类型不匹配 dtype.
  • TypeError:如果 on_value 和 off_value 的 dtype 不匹配.