使用cifar100上训练的resnet18进行ood测试

news/2024/9/28 11:09:39

以cifar100作为闭集(closed-set)数据集,使用resnet18模型进行训练,然后在常见的开集(out-of-distribution)数据集上进行OOD检测。使用MSP(Maximum Softmax Probability)作为OOD检测的依据。

开集噪声数据集使用gaussian, rademacher, blob, svhn四种类型。其中gaussian、rademacher、blob是生成的随机噪声,svhn是额外引入的噪声数据集。

输出结果

Error Rate 46.3000
AUROC: 81.9790, AUPR: 85.7377, FPR95: 73.3909
ood type: gaussian
AUROC: 68.1596, AUPR: 92.9277, FPR95: 99.4000
ood type: rademacher
AUROC: 69.9099, AUPR: 93.1788, FPR95: 96.1500
ood type: blob
AUROC: 68.0615, AUPR: 92.7477, FPR95: 97.5500
ood type: svhn
AUROC: 66.9684, AUPR: 91.6508, FPR95: 89.0500

可以看到,在使用简单的交叉熵损失且不经过其他处理的resnet18,在开集检测上的表示并不好。

闭集数据集上训练一个resnet18

# train.py
import torch
from torch.optim.lr_scheduler import MultiStepLR
from torchvision.datasets.cifar import CIFAR100
from torch.utils.data.dataloader import DataLoader
from torchvision import transforms
from torchvision.models import resnet18
from sklearn.metrics import accuracy_score
import torch.nn.functional as Fdef get_transform(train=True):mean = [0.4914, 0.4822, 0.4465]std = [0.2023, 0.1994, 0.2010]if train:transform = transforms.Compose([transforms.RandomCrop(32, padding=4),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize(mean, std),])else:transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize(mean, std),])return transformdef get_loader(train=True):transform = get_transform(train)dataset = CIFAR100(root='~/data', train=train, transform=transform)loader = DataLoader(dataset, batch_size=128, shuffle=train, num_workers=8, pin_memory=True)return loaderdef train_model():loader = get_loader(train=True)test_loader = get_loader(train=False)model = resnet18(num_classes=100)model = model.cuda()epochs = 100optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4)scheduler = MultiStepLR(optimizer, milestones=[50, 75], gamma=0.1)model.eval()all_preds = []all_labels = []for epoch in range(epochs):model.train()print('Training')for i, (inputs, labels) in enumerate(loader):inputs, labels = inputs.cuda(), labels.cuda()outputs = model(inputs)loss = F.cross_entropy(outputs, labels)# 反向传播和优化optimizer.zero_grad()loss.backward()optimizer.step()if i % 100 == 0:print(f'Epoch[{epoch}] Iter: {i}/{len(loader)} Loss: {loss.item()}')scheduler.step()print('Testing')for inputs, labels in test_loader:inputs = inputs.cuda()outputs = model(inputs)_, predicted = torch.max(outputs.data, 1)all_preds.extend(predicted.cpu().numpy())all_labels.extend(labels.numpy())accuracy = accuracy_score(all_labels, all_preds)print(f'Epoch[{epoch}] acc@1 {accuracy:.4f}')torch.save(model.state_dict(), 'cifar100_resnet18.pth')if __name__ == '__main__':train_model()

构建常用的开集数据集

# ood_data.py
import torch
import numpy as np
from torch.utils.data.dataset import TensorDataset
from torch.utils.data.dataloader import DataLoader
from skimage.filters import gaussian
from torchvision.datasets import SVHNfrom train import get_transformdef build_ood_loader(noise_type, ood_num_examples, batch_size, worker):dummy_targets = torch.ones(ood_num_examples)if noise_type in ['gaussian', 'rademacher', 'blob']:if noise_type == 'gaussian':ood_data = torch.from_numpy(np.float32(np.clip(np.random.normal(size=(ood_num_examples, 3, 32, 32), scale=0.5), -1, 1)))elif noise_type == 'rademacher':ood_data = torch.from_numpy(np.random.binomial(n=1, p=0.5, size=(ood_num_examples, 3, 32, 32)).astype(np.float32)) * 2 - 1else:ood_data = np.float32(np.random.binomial(n=1, p=0.7, size=(ood_num_examples, 32, 32, 3)))for i in range(ood_num_examples):ood_data[i] = gaussian(ood_data[i], sigma=1.5)ood_data[i][ood_data[i] < 0.75] = 0.0ood_data = torch.from_numpy(ood_data.transpose((0, 3, 1, 2))) * 2 - 1dataset = TensorDataset(ood_data, dummy_targets)elif noise_type == 'svhn':transform = get_transform(train=False)dataset = SVHN(root='~/data/svhn', split='test', transform=transform, download=True)data = dataset.data[:ood_num_examples]dataset.data = dataelse:raise ValueError(f'Unknown noise type: {noise_type}')dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False,num_workers=worker, pin_memory=True)return dataloader

使用常见的OOD检测评估指标

# ood_utils.py
import torch
import numpy as np
import torch.nn.functional as F
from sklearn.metrics import roc_auc_score, average_precision_score, roc_curve@torch.no_grad()
def get_ood_scores(model, dataloader, closed_set=False):model.eval()scores = []right_scores = []wrong_scores = []for i, (data, targets) in enumerate(dataloader):data = data.cuda()output = model(data)smax = F.softmax(output, dim=1).cpu().numpy()scores.append(np.max(smax, axis=1))if closed_set:pred = np.argmax(smax, axis=1)targets = targets.numpy().squeeze()right_indices = pred == targetswrong_indices = np.invert(right_indices)right_scores.append(np.max(smax[right_indices], axis=1))wrong_scores.append(np.max(smax[wrong_indices], axis=1))if closed_set:return (np.concatenate(scores),np.concatenate(right_scores),np.concatenate(wrong_scores))else:return np.concatenate(scores)def get_performance(pos, neg):pos = np.array(pos).reshape(-1)neg = np.array(neg).reshape(-1)scores = np.concatenate([pos, neg])labels = [1] * len(pos) + [0] * len(neg)auroc = roc_auc_score(labels, scores)aupr = average_precision_score(labels, scores)fpr, tpr, _ = roc_curve(labels, scores)fpr95 = fpr[np.argmax(tpr >= 0.95)]return auroc, aupr, fpr95def show_performance(pos, neg):auroc, aupr, fpr95 = get_performance(pos, neg)print(f"AUROC: {auroc * 100:.4f}, AUPR: {aupr * 100:.4f}, FPR95: {fpr95 * 100:.4f}")

测试模型的OOD检测性能

# test.py
import torch
from torchvision.models import resnet18from train import get_loader
from ood_utils import get_ood_scores, show_performance
from ood_data import build_ood_loaderdef evaluate():model = resnet18(num_classes=100)model.load_state_dict(torch.load('cifar100_resnet18.pth'))model = model.cuda()# closed-set testtest_loader = get_loader(train=False)in_score, right_score, wrong_score = get_ood_scores(model, test_loader, True)num_right, num_wrong = len(right_score), len(wrong_score)print(f'Error Rate {100 * num_wrong / (num_right + num_wrong):.4f}')show_performance(right_score, wrong_score)# open-set testood_num_examples = len(test_loader.dataset) // 5ood_types = ['gaussian', 'rademacher', 'blob', 'svhn']for i in ood_types:print(f'ood type: {i}')ood_loader = build_ood_loader(i, ood_num_examples, batch_size=128, worker=8)out_score = get_ood_scores(model, ood_loader)show_performance(in_score, out_score)if __name__ == '__main__':evaluate()

依赖

scikit-learn       1.5.2
scipy              1.14.1
torch              2.4.1

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.ryyt.cn/news/65707.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈,一经查实,立即删除!

相关文章

京东面试:RR隔离mysql如何实现?什么情况RR不能解决幻读?

文章很长,且持续更新,建议收藏起来,慢慢读!疯狂创客圈总目录 博客园版 为您奉上珍贵的学习资源 : 免费赠送 :《尼恩Java面试宝典》 持续更新+ 史上最全 + 面试必备 2000页+ 面试必备 + 大厂必备 +涨薪必备 免费赠送 :《尼恩技术圣经+高并发系列PDF》 ,帮你 实现技术自由,…

固态硬盘接入电脑没有反应

当固态硬盘(SSD)接入电脑后没有反应时,可能由多种原因造成。以下是一些常见的原因及其解决方法: 一、物理连接问题 检查接口连接: 确保SSD的SATA接口或M.2接口(视SSD类型而定)与主板连接牢固,没有松动或错位。 检查SATA数据线或M.2插槽是否损坏,如有必要,更换新的数据…

一些超好用的 GitHub 插件和技巧

聊聊我平时使用 GitHub 时学到的一些插件、技巧。聊聊我平时使用 GitHub 时学到的一些插件、技巧。 ‍ ‍ 浏览器插件 在我的另一篇博客 浏览器插件推荐 里提到过跟 GitHub 相关的一些插件,这里重复下:Sourcegraph:在线打开项目,方便阅读,将 GitHub 变得和 IDE 一般,集成…

java第一次正式课程课后习题

s和t并非引用同一对象,不同的值引用不同对象,相同值引用相同对象。 枚举类型并非原始数据类型,而是引用数据类型。 采用.velueof和.从枚举类型中赋值效果相同。java中的数采用补码形式表示。由示例可知,局部变量与全局变量重名时会在局部屏蔽全局变量,采用局部变量。java中…

橡胶 在经历大C浪的反弹

下跌部分 开启大ABC的反弹:

九月二十八

以下代码的输出结果是什么? int X=100; int Y=200; System.out.println("X+Y="+X+Y); System.out.println(X+Y+"=X+Y"); 为什么会有这样的输出结果? 输出结果是: X+Y=100200 100200=X+Y 出现这样的输出结果是因为在Java中,当多个值连接在一起时,会根据…

九月二十七2

当需要处理非常大或非常小的数值时,应选择float或double类型。 当需要处理字符或需要较大范围的无符号整数时,应选择char类型。 当需要在内存和处理速度之间做出权衡时,可以根据需要选择适当的整数类型(byte, short, int, long)。 对于需要精确计算的场景,应避免使用浮点…

PARTVI-Oracle数据库管理与开发-数据库管理员和开发人员的主题

17.数据库管理员和开发人员的主题 17.1. 数据库安全概述 通常情况下,数据库安全涉及用户认证、加密、访问控制和监控。 17.1.1. 用户账户 每个Oracle数据库都有一个有效数据库用户的列表。数据库包含几个默认账户,包括默认的管理员账户SYSTEM(参见第2-5页的“SYS和SYSTEM模式…