您好, 欢迎来到 !    登录 | 注册 | | 设为首页 | 收藏本站

如何使用 TensorFlow 根据预测标签和真实标签创建混淆矩阵?

如何使用 TensorFlow 根据预测标签和真实标签创建混淆矩阵?

这段代码对我有用，我把它整理了一下 :)

from sklearn.metrics import precision_recall_fscore_support as score
from sklearn.metrics import classification_report

def print_confusion_matrix(plabels, tlabels):
    """Build a confusion matrix of true vs. predicted labels.

    Despite its name, this function prints nothing: it returns the table so
    the caller can display or further process it.

    Input:
    -----------
    plabels: predicted labels, one per sample (any 1-D sequence/array).
    tlabels: true labels, one per sample, aligned position-by-position
             with ``plabels``.

    Returns:
    -----------
    pandas.DataFrame cross-tabulation with the true labels as rows (index
    name ``'Actual'``), the predicted labels as columns (column name
    ``'Predicted'``), plus an ``'All'`` row and column holding the totals.

    Raises:
    -----------
    ValueError: if ``plabels`` and ``tlabels`` differ in length. Without
    this check, pandas would silently align the two Series on their index
    and drop the unmatched samples from the counts.

    code from: http://stackoverflow.com/questions/2148543/how-to-write-a-confusion-matrix-in-python
    """
    import pandas as pd

    plabels = pd.Series(plabels)
    tlabels = pd.Series(tlabels)
    if len(plabels) != len(tlabels):
        raise ValueError(
            "plabels and tlabels must have the same length "
            "(got %d and %d)" % (len(plabels), len(tlabels)))

    # Cross-tabulate: rows = actual class, columns = predicted class;
    # margins=True appends the 'All' totals row and column.
    return pd.crosstab(tlabels, plabels, rownames=['Actual'],
                       colnames=['Predicted'], margins=True)

def confusionMatrix(text, Labels, y_pred, not_partial):
    """Print a titled confusion matrix and, optionally, a classification report.

    Parameters
    ----------
    text : str
        Title printed above the matrix (e.g. "Complete Confusion matrix").
        Previously ignored, which left the matrices indistinguishable.
    Labels : array-like
        Ground-truth labels, one row per sample. Assumed to be one-hot
        encoded: the true class is the column holding the 1 / the maximum.
    y_pred : array-like
        One predicted value per sample, cross-tabulated against the decoded
        true class.
    not_partial : bool
        When true, also print sklearn's per-class classification report.
    """
    # Decode the one-hot rows into class indices. argmax yields exactly one
    # entry per sample, so y_actu always lines up with y_pred. The previous
    # np.where(Labels[:] == 1)[1] dropped all-zero rows, duplicated multi-hot
    # rows, and failed when Labels was a plain Python list.
    y_actu = np.argmax(np.asarray(Labels), axis=1)
    df = print_confusion_matrix(y_pred, y_actu)
    print("\n%s\n%s" % (text, df))
    if not_partial:
        print("\n" + classification_report(y_actu, y_pred))
    print("\n\t------------------------------------------------------\n")

def do_eval(message, sess, correct_prediction, accuracy, pred, X_, y_, x, y):
    """Evaluate the model on (X_, y_) and print accuracy plus confusion matrices.

    Parameters
    ----------
    message : str
        Prefix printed before the accuracy value.
    sess : tf.Session
        Open session used for every evaluation; it is passed explicitly
        instead of relying on a default session.
    correct_prediction, accuracy, pred : tf.Tensor
        Graph tensors for the per-sample correctness, the mean accuracy and
        the model output. The predicted class is ``argmax(pred, 1)``.
    X_, y_ :
        Inputs and labels fed to the placeholders ``x`` and ``y``.
    x, y : tf.Tensor
        Input and label placeholders.
    """
    feed = {x: X_, y: y_}
    # Fetch all three results in a single run (one forward pass). The keyword
    # is `feed_dict`; the former `Feed_dict` raised TypeError.
    correct, labels, acc = sess.run(
        [correct_prediction, tf.argmax(pred, 1), accuracy], feed_dict=feed)
    print("%s %s\n" % (message, acc))
    # Partial matrix: true class vs. correct/incorrect (True/False).
    confusionMatrix("Partial Confusion matrix", y_, correct, False)
    # Complete matrix: true class vs. predicted class.
    confusionMatrix("Complete Confusion matrix", y_, labels, True)

# Launch the graph.
# NOTE(review): this script relies on names defined earlier in the original
# program that are not shown here: tf, np, init, optimizer, cost, pred, the
# placeholders x / y, the X_*/y_* data splits, batch_size, training_epochs,
# display_step, loss_history / train_acc_history / val_acc_history and do_eval.

# Build the evaluation ops once, before training. Creating tf.equal /
# tf.reduce_mean inside the epoch loop added new nodes to the graph on every
# display step, so the graph kept growing.
correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))

with tf.Session() as sess:
    sess.run(init)

    # Shuffle through an index permutation instead of zipping (x, y) pairs
    # into an object array. zip() is an iterator on Python 3, which broke
    # both np.array(zip(...)) and sample[0].
    X_train_arr = np.asarray(X_train)
    y_train_arr = np.asarray(y_train)
    data_size = len(X_train_arr)
    # Ceiling division. The former int(len / batch_size) + 1 produced an
    # empty final batch whenever data_size was a multiple of batch_size.
    num_batches_per_epoch = (data_size + batch_size - 1) // batch_size

    for epoch in range(training_epochs):
        avg_cost = 0.
        # Shuffle the data at each epoch
        shuffle_indices = np.random.permutation(data_size)
        for batch_num in range(num_batches_per_epoch):
            start_index = batch_num * batch_size
            end_index = min(start_index + batch_size, data_size)
            batch_idx = shuffle_indices[start_index:end_index]
            feed = {x: X_train_arr[batch_idx], y: y_train_arr[batch_idx]}
            # Take one optimization step and fetch the batch loss in the same
            # run. The keyword is feed_dict; "Feed_dict" raised TypeError.
            _, batch_cost = sess.run([optimizer, cost], feed_dict=feed)
            # Accumulate the average loss over the epoch
            avg_cost += batch_cost / num_batches_per_epoch
        loss_history.append(avg_cost)

        # Display logs per epoch step
        if epoch % display_step == 0:
            # Training / validation accuracy (default session = sess here)
            trainAccuracy = accuracy.eval({x: X_train, y: y_train})
            train_acc_history.append(trainAccuracy)
            valAccuracy = accuracy.eval({x: X_val, y: y_val})
            val_acc_history.append(valAccuracy)
            print("Epoch: %04d cost= %.9f train= %s val= %s"
                  % (epoch + 1, avg_cost, trainAccuracy, valAccuracy))

    print("Optimization Finished!\n")

    # Evaluate the model on the gold test set. This must stay inside the
    # `with` block because the session is closed once the block exits.
    do_eval("Accuracy of Gold Test set Results: ", sess, correct_prediction,
            accuracy, pred, X_gold, y_gold, x, y)

以下是示例输出：

Accuracy of Gold Test set Results:  0.642608


Predicted  False  True  All
Actual                     
0             20    46   66
1              3     1    4
2             21     1   22
3              8     4   12
4             16     7   23
5             54   259  313
6             41    14   55
7             11     2   13
8             48    94  142
9             29     4   33
10            17     4   21
11            39   116  155
All          307   552  859

Predicted   0  1  2   3   4    5   6   7    8   9  10   11  All
Actual                                                         
0          46  0  0   0   0    8   0   2    2   2   0    6   66
1           0  1  0   1   0    2   0   0    0   0   0    0    4
2           3  0  1   3   0   12   0   0    1   0   0    2   22
3           2  0  0   4   1    3   1   1    0   0   0    0   12
4           1  0  0   0   7   12   0   0    1   0   0    2   23
5           8  0  0   1   5  259   9   0    9   3   1   18  313
6           1  0  0   1   6   30  14   1    2   0   0    0   55
7           3  0  0   0   0    2   0   2    4   0   1    1   13
8           6  0  0   1   1   18   0   3   94   8   1   10  142
9           9  0  0   0   0    1   1   1    9   4   0    8   33
10          1  0  0   0   3    6   0   1    1   0   4    5   21
11          5  1  0   1   0   18   1   0    6   5   2  116  155
All        85  2  1  12  23  371  26  11  129  22   9  168  859

         precision    recall  f1-score   support

      0       0.54      0.70      0.61        66
      1       0.50      0.25      0.33         4
      2       1.00      0.05      0.09        22
      3       0.33      0.33      0.33        12
      4       0.30      0.30      0.30        23
      5       0.70      0.83      0.76       313
      6       0.54      0.25      0.35        55
      7       0.18      0.15      0.17        13
      8       0.73      0.66      0.69       142
      9       0.18      0.12      0.15        33
     10       0.44      0.19      0.27        21
     11       0.69      0.75      0.72       155

     avg / total       0.64      0.64      0.62       859
其他 2022/1/1 18:26:21 有484人围观

撰写回答


你尚未登录,登录后可以

和开发者交流问题的细节

关注并接收问题和回答的更新提醒

参与内容的编辑和改进,让解决方法与时俱进

请先登录

推荐问题


联系我
置顶