阅读(2826)
赞(12)
TensorFlow 指标(contrib)
2017-08-23 15:33:09 更新
评估指标和汇总统计的操作.
API
该模块提供计算流式指标的函数:以动态价值计算的指标 Tensors。每个指标声明返回一个“value_tensor”,一个返回指标当前值的幂等运算,还有一个“update_op”,这是一个从当前Tensors 测量值累加信息的操作,并返回 “ value_tensor”。
要使用这些指标中的任何一个,只需声明指标,update_op 重复调用即可将数据累加到所需数量的 Tensor 值(通常每个都是单个批次),最后评估 value_tensor。例如,要使用streaming_mean:
value = ...
mean_value, update_op = tf.contrib.metrics.streaming_mean(values)
sess.run(tf.local_variables_initializer())
for i in range(number_of_batches):
print('Mean after batch %d: %f' % (i, update_op.eval())
print('Final Mean: %f' % mean_value.eval())
每个指标函数将节点添加到图表中,该图保持计算指标值所需的状态以及实际执行计算的一组操作。每个指标评估由三个步骤组成
- 初始化:初始化指标标准状态
- 聚合:更新指标标准状态的值
- 完成:计算最终的指标值
在上面的例子中,调用 streaming_mean 创建一对状态变量,它们将包含(1)运行总和和(2)总和中样本数的计数。因为流式指标使用局部变量,所以初始化阶段通过运行tf.local_variables_initializer() 返回的操作来执行的。它将 sum 和 count 变量设置为零。
接下来,通过 values 适当地检查状态变量的当前状态和递增来执行聚合.此步骤通过运行由指标数返回的 update_op 来执行。
最后,通过评估 “value_tensor”
实际上,我们通常希望评估多个批次和多个指标。为此,我们只需要多次运行指标计算操作:
labels = ...
predictions = ...
accuracy, update_op_acc = tf.contrib.metrics.streaming_accuracy(
labels, predictions)
error, update_op_error = tf.contrib.metrics.streaming_mean_absolute_error(
labels, predictions)
sess.run(tf.local_variables_initializer())
for batch in range(num_batches):
sess.run([update_op_acc, update_op_error])
accuracy, error = sess.run([accuracy, error])
请注意,在不同输入上多次评估相同的指标时,必须指定每个指标的范围,以避免将结果累加在一起:
labels = ...
predictions0 = ...
predictions1 = ...
accuracy0 = tf.contrib.metrics.accuracy(labels, predictions0, name='preds0')
accuracy1 = tf.contrib.metrics.accuracy(labels, predictions1, name='preds1')
某些指标(如 streaming_mean 或 streaming_accuracy)可以通过 weights 参数加权。权重张量的大小必须作为指标的加权平均值的标签,并且要和预测张量和结果相同。
TensorFlow 指标“操作”
- tf.contrib.metrics.streaming_accuracy
- tf.contrib.metrics.streaming_mean
- tf.contrib.metrics.streaming_recall
- tf.contrib.metrics.streaming_recall_at_thresholds
- tf.contrib.metrics.streaming_precision
- tf.contrib.metrics.streaming_precision_at_thresholds
- tf.contrib.metrics.streaming_auc
- tf.contrib.metrics.streaming_recall_at_k
- tf.contrib.metrics.streaming_mean_absolute_error
- tf.contrib.metrics.streaming_mean_iou
- tf.contrib.metrics.streaming_mean_relative_error
- tf.contrib.metrics.streaming_mean_squared_error
- tf.contrib.metrics.streaming_mean_tensor
- tf.contrib.metrics.streaming_root_mean_squared_error
- tf.contrib.metrics.streaming_covariance
- tf.contrib.metrics.streaming_pearson_correlation
- tf.contrib.metrics.streaming_mean_cosine_distance
- tf.contrib.metrics.streaming_percentage_less
- tf.contrib.metrics.streaming_sensitivity_at_specificity
- tf.contrib.metrics.streaming_sparse_average_precision_at_k
- tf.contrib.metrics.streaming_sparse_precision_at_k
- tf.contrib.metrics.streaming_sparse_precision_at_top_k
- tf.contrib.metrics.streaming_sparse_recall_at_k
- tf.contrib.metrics.streaming_specificity_at_sensitivity
- tf.contrib.metrics.streaming_concat
- tf.contrib.metrics.streaming_false_negatives
- tf.contrib.metrics.streaming_false_negatives_at_thresholds
- tf.contrib.metrics.streaming_false_positives
- tf.contrib.metrics.streaming_false_positives_at_thresholds
- tf.contrib.metrics.streaming_true_negatives
- tf.contrib.metrics.streaming_true_negatives_at_thresholds
- tf.contrib.metrics.streaming_true_positives
- tf.contrib.metrics.streaming_true_positives_at_thresholds
- tf.contrib.metrics.auc_using_histogram
- tf.contrib.metrics.accuracy
- tf.contrib.metrics.aggregate_metrics
- tf.contrib.metrics.aggregate_metric_map
- tf.contrib.metrics.confusion_matrix
TensorFlow 设置“操作”
- tf.contrib.metrics.set_difference
- tf.contrib.metrics.set_intersection
- tf.contrib.metrics.set_size
- tf.contrib.metrics.set_union