首页 > 编程 > Python > 正文

Pytorch 实现计算分类器准确率(总分类及子分类)

2020-02-15 21:29:07
字体:
来源:转载
供稿:网友

分类器平均准确率计算:

correct = torch.zeros(1).squeeze().cuda()total = torch.zeros(1).squeeze().cuda()for i, (images, labels) in enumerate(train_loader):      images = Variable(images.cuda())      labels = Variable(labels.cuda())      output = model(images)      prediction = torch.argmax(output, 1)      correct += (prediction == labels).sum().float()      total += len(labels)acc_str = 'Accuracy: %f'%((correct/total).cpu().detach().data.numpy())

分类器各个子类准确率计算:

correct = list(0. for i in range(args.class_num))total = list(0. for i in range(args.class_num))for i, (images, labels) in enumerate(train_loader):      images = Variable(images.cuda())      labels = Variable(labels.cuda())      output = model(images)      prediction = torch.argmax(output, 1)      res = prediction == labels      for label_idx in range(len(labels)):        label_single = label[label_idx]        correct[label_single] += res[label_idx].item()        total[label_single] += 1 acc_str = 'Accuracy: %f'%(sum(correct)/sum(total)) for acc_idx in range(len(train_class_correct)):      try:        acc = correct[acc_idx]/total[acc_idx]      except:        acc = 0      finally:        acc_str += '/tclassID:%d/tacc:%f/t'%(acc_idx+1, acc)

以上这篇Pytorch 实现计算分类器准确率(总分类及子分类)就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持武林站长站。

发表评论 共有条评论
用户名: 密码:
验证码: 匿名发表