Python迁移学习:用Torchvision、Pytorch进行交通标志图像分类|附代码数据

news/2024/9/20 15:07:18

原文链接:https://tecdat.cn/?p=36539

原文出处:拓端数据部落公众号

本研究旨在探索如何应用迁移学习技术对交通标志图像进行分类。通过构建适用于Torchvision的图像数据集,并利用预训练模型进行微调,我们实现了对原始像素的交通标志图像的分类。此外,我们还引入了一个新的“未知”类别,并对模型进行了重新训练,以提高其在实际应用中的泛化能力。

随着深度学习技术的快速发展,图像分类在交通管理、自动驾驶等领域的应用日益广泛。然而,对于特定的图像分类任务,如交通标志识别,从头开始训练一个深度学习模型往往需要大量的时间和计算资源。因此,迁移学习技术应运而生,它通过利用在大型数据集上预训练的模型,可以大大加快模型的训练速度并提高分类性能。

方法

在本研究中,我们采用了以下步骤来构建和训练交通标志图像分类模型:

  1. 交通标志图像数据集概述:我们首先对所使用的交通标志图像数据集进行了概述,包括数据集的来源、规模、类别分布等信息。
  2. 构建数据集:我们将原始图像数据转换为适用于Torchvision的数据集格式,并进行了必要的数据预处理和增强操作,以提高模型的泛化能力。
  3. 使用Torchvision的预训练模型:我们选择了一个在大型数据集上预训练的深度学习模型作为起点,通过对其进行微调,使其适应交通标志图像的分类任务。
  4. 添加新的“未知”类别并重新训练模型:为了处理实际应用中可能出现的未知类别的图像,我们在数据集中添加了一个新的“未知”类别,并对模型进行了重新训练。通过这种方法,模型可以在遇到未知类别的图像时给出相应的预测结果。

image.png

image.png

配置

 
 
%reload_ext watermark
%watermark -v -p numpy,pandas,torch,torchvision

image.png

 
 

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

image.png

交通标志识别

德国交通标志识别基准(GTSRB)包含了超过50,000张带有40多种交通标志注释的图像。给定一张图像,您需要识别出其中的交通标志。

 
 
!unzip -qq GTSRB_Final_Training_Images.zip

image.png

代码模拟

让我们先来了解一下数据。每个交通标志的图像都存储在一个单独的目录中。我们有多少个?

 
 
len(train_folders)

image.png

我们将创建 3 个辅助函数,使用 OpenCV 和 Torchvision 来加载和显示图像:

 
 
def load_image(img_path, resize=True):img = cv2.cvtColor(cv2.imread(img_path), cv2.COLOR_BGR2RGB)if resize:img = cv2.resize(img, (64, 64), interpolation = cv2.INTER_AREA)

让我们看看每个交通标志的一些示例:

 
 
sample_images = [np.random.choice(glob(f'{tf}/*ppm')) for tf in train_folders]
show_sign_grid(sample_images)

微信图片_20240621183307.png

这里有一个标志:

 
 
img_path = glob(f'{train_folders[16]}/*ppm')[1]show_image(img_path)

image.png

建立数据集

为了简单起见,我们将重点对一些最常用的交通标志进行分类:

 
 
class_names = ['priority_road', 'give_way', 'stop', 'no_entry']class_indices = [12, 13, 14, 17]

我们将把图像文件复制到一个新的目录中,以便于使用 Torchvision 的数据集助手。让我们从每个类的目录开始:

 
 
for ds in DATASETS:for cls in class_names:(DATA_DIR / ds / cls).mkdir(parents=True,

我们将为每个类别保留 80% 的图像用于训练,10% 用于验证,10% 用于测试。将把每张图片复制到正确的数据集目录下:

 
 
for i, cls_index in enumerate(class_indices):image_paths = np.array(glob(f'{train_folders[cls_index]}/*.ppm'))class_name = class_names[i]

image.png

我们的类别不平衡,但并不严重。我们可以忽略它。

我们将应用一些图像增强技术,人为地增加训练数据集的大小:

 
 
transforms = {'train': T.Compose([T.RandomResizedCrop(size=256),T.RandomRotation(degrees=15),T.RandomHorizontalFlip(),

我们会随机调整大小、旋转和水平翻转。最后,我们使用每个通道的预设值对张量进行归一化处理。

这是 Torchvision 中预训练模型的要求。

我们将为每个图像数据集文件夹和数据加载器创建一个 PyTorch 数据集,以方便训练:

我们还将存储每个数据集中的示例数量和类名,以备日后使用:

 
 
dataset_sizes = {d: len(image_datasets[d]) for d in DATASETS

image.png

让我们来看看一些应用了转换的图像示例。我们还需要反转归一化并重新排列颜色通道,以获得正确的图像数据:

 
 
def imshow(inp, title=None):inp = inp.numpy().transpose((1, 2, 0))mean = np.array([mean_nums])

image.png

使用预训练模型:

我们的模型将接收原始图像像素,并尝试将它们分类为四个交通标志之一。这有多难?试试从头开始建立一个模型。

在这里,我们将使用迁移学习 复制非常流行的ResNet 模型的架构。此外,我们还将使用在 ImageNet 数据集 上训练时学习到的模型权重。Torchvision 让所有这些都变得简单易用:

 
 
def create_model(n_classes):model = models.resnet34(pretrained=True)

除了输出层的变化,我们几乎重复使用了所有内容。这是因为我们数据集中的类数与 ImageNet 不同。

让我们创建一个模型实例:

image.png

训练

我们将编写 3 个辅助函数来封装训练和评估逻辑。首先是 train_epoch

 
 
  loss.backward()optimizer.step()optimizer.zero_grad()scheduler.step()

首先,我们将模型调至训练模式,然后查看数据。在得到预测结果后,我们会得到概率最大的类别以及损失,这样我们就能计算出历时损失和准确率。

请注意,我们还使用了学习率调度器。

 
 
    losses.append(loss.item())return correct_predictions.double() / n_examples, np.mean(losses)

除了不进行梯度计算外,对模型的评估非常相似。

让我们把所有东西放在一起:

 
 
  model.load_state_dict(torch.load('best_model_state.bin'))return model, history

我们做了大量的字符串格式化和训练历史记录工作。困难的工作会委托给前面的辅助函数。我们还希望获得最佳模型,因此在训练过程中会存储最准确模型的权重。

让我们来训练第一个模型:

image.png

这里有一个小辅助函数,可以将训练历史可视化:

 
 

plot_training_history(history):fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(18, 6))

image.png

预先训练好的模型非常出色,我们在 3 个历时后获得了非常高的准确率和较低的损失。遗憾的是,我们的验证集太小,无法从中获得一些有意义的指标。

评估

让我们看看测试集中对交通标志的预测:

 
 
def show_predictions(model, class_names, n_images=6):model = model.eval()images_handeled = 0

image.png

即使是几乎看不见的优先道路标志也能正确分类。让我们再深入一点。

我们先从模型中获取预测结果:

image.png

 
 
show_confusion_matrix(cm, class_names)

7176b9f814ec45b59497_1719212735.1162963.png

没有错误。

未见图像分类

好了,但当我们面对真实世界的图像时,我们的模型会有多好呢?让我们来看看:

image.png

 
 
show_image('stop-sign.jpg')

image.png

为此,我们将查看每个类别的置信度。让我们从模型中获取:

 
 
 predict_proba(base_model, 'stop-sign.jpg')

image.png

这有点难以理解。让我们来绘制一下:

 
 
 })sns.barplot(x='values', y='class_names', data=pred_df, orient='h')plt.xlim([0, 1]);

image.png

我们的模型再次表现出色。

分类未知交通标志

我们的模型面临的最后一个挑战是从未见过的交通标志:

image.png

 
 
show_image('unknown-sign.jpg')

image.png

让我们来预测一下:

 
 
predict_proba(base_model, 'unknown-sign.jpg')

image.png

image.png

我们的模型非常确定(超过 95% 的置信度)这是一个让路信号。这显然是错误的。如何才能让你的模型看到这一点呢?

添加 "未知 "类

虽然有多种方法可以处理这种情况,但我们要做的事情更简单。

我们将获取原始数据集中未包含的所有交通标志的索引:

我们将为未知类创建一个新文件夹,并在其中复制一些图像:

 
 
for ds, images in dataset_data:for img_path in images:shutil.copy(img_path, f'{DATA_DIR}/{ds}/unknown/')

接下来的步骤与我们已经做的完全相同:

 
 
class_names = image_datasets['train'].classesdataset_sizes

image.png

image.png

 
 
raining_history(history)

image.png

同样,我们的模型学习速度非常快。让我们再来看看样本图像:

image.png

 
 
prediction_confidence(pred, class_names)

image.png

很好,这个模型并不重视任何已知类别。它不知道这是一个双向符号,但却承认它是未知的。

让我们看看新数据集的一些例子:

image.png

让我们来了解一下这款新车型的性能:

 
 
report(y_test, y_pred, target_names=clas

image.png

image.png

我们的模型依然完美。

总结

您训练了两种不同的模型,用于根据原始像素对交通标志进行分类。

以下是所学到的内容:

  • 交通标志图像数据集概述
  • 建立数据集
  • 使用 Torchvision 预先训练的模型
  • 添加新的未知类并重新训练模型

QQ截图20220608234400.png

Rethinking-reskilling-for-the-post-pandemic-world.jpg

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.ryyt.cn/news/47491.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈,一经查实,立即删除!

相关文章

Lampiao靶场实操

本文是基于Vulnhub中的Lampiao靶机的实操Lampiao靶场实操 前言下载靶机解压后,用vm打开即可lampiao靶机地址:https://www.vulnhub.com/entry/lampiao-1,249/ 靶场发布日期:2018年7月28日 目标:Get root!kali:192.168.1.131 靶机:信息收集 打开靶场以及kali,使用kali中的…

算法金 | A - Z,115 个数据科学 机器学习 江湖黑话(全面)

大侠幸会,在下全网同名「算法金」 0 基础转 AI 上岸,多个算法赛 Top 「日更万日,让更多人享受智能乐趣」机器学习本质上和数据科学一样都是依赖概率统计,今天整整那些听起来让人头大的机器学习江湖黑话A - C A/B Testing (A/B 测试) A/B测试是一种在线实验,通过对比测试两…

《编译原理》阅读笔记:p18

《编译原理》学习第 3 天,p18总结,总计 14页。 一、技术总结 1.assembler (1)计算机结构 要想学习汇编的时候更好的理解,要先了解计算机的结构,以下是本人学习汇编时总结的一张图,每当学习汇编时,看到“计数器”,“解码器”,“寄存器”,“数据总线”等概念时,就知道…

QT实现简易串口助手

简易串口助手界面设计 自定义了QWidget的派生类SerialPortWidget,其界面设计大致如下:效果图:项目结构 serial_port└─│ main.cpp //主程序│ mainwindow.cpp //主窗口源文件和头文件│ mainwindow.h│ myintvalidator.cpp //QIntValidator的派生类,重写fixup…

GraqphQL 学习

GraphQL是Graph+QL。Graph是图,描述数据最好的方式是图数据结构(包括树),数据和数据之间,有像图一样的联系,以图的思维来考虑数据。QL是query language,像写query语句一样请求数据,query什么数据,就返回什么数据。怎样用图的方式来描述数据?定义Schema(类型), type 类…

Android :安卓学习笔记之 Handler机制 的简单理解和使用

目录Handler机制1、Handler使用的引出2、背景和定义3、作用和意义4、主要参数5、工作原理及流程5.1、对应关系6、深入分析 Handler机制源码6.1、Handler机制的核心类6.2、核心方法6.3、方式1:使用 Handler.sendMessage()6.3.1、 创建Handler类对象6.3.1.1、隐式操作1:创建循…

try_catch处理异常

try_catch使用不影响其他功能的使用 这点比无脑throws强 public class Demo06Exception { public static void main(String[] args){ String s="a.txt"; try{ add(s); } catch (FileNotFoundException e){ System.out.println(e); } delete(); updata(); find(); } …

麒麟v10 SP2系统容器化部署MySQL、RabbitMQ、redis-ha-haproxy等出现OOMkill内存异常升高问题处理

问题场景 操作系统:麒麟系统v10 SP2 k8s版本:v1.23.17 容器运行时:containerd Rabbitmq镜像版本:3.9.11-debian-10-r0(3.11.10-debian-11-r0版本已正常) Mysql镜像版本:mysql_5.7.37-debian-10-r95(8.0.20版本已正常) redis-ha-haproxy镜像版本:haproxy:2.0.22-alpin…