OpenCV + sklearnSVM 实现手写数字分割和识别

news/2024/10/2 22:25:48

这学期机器学习考核方式以大作业的形式进行考核,而且只能使用一些传统的机器学习算法。
综合再三,选择了自己比较熟悉的MNIST数据集以及OpenCV来完成手写数字的分割和识别作为大作业。

1. 数据集准备

MNIST数据集是一个手写数字的数据库,包含60000张训练图片和10000张测试图片,每张图片大小为28x28像素,每张图片都是一个
灰度图,像素取值范围在0-255之间。

这里使用pytorch的torchvision.datasets模块来读取MNIST数据集。

from torchvision import datasets
mnist_set = datasets.MNIST(root="./MNIST", train=True, download=True)

具体参数说明请自行搜索。注意若donwload=True,则torchvision会通过内置链接自动下载数据集,
但是有时会失效。因此可以自己去网络上下载并解压后排列成指定文件树,如下

MNIST
├── MNSIT
│   ├── raw
│   │   ├── t10k-images-idx3-ubyte.gz
│   │   ├── t10k-labels-idx1-ubyte.gz
│   │   ├── train-images-idx3-ubyte.gz
│   │   └── train-labels-idx1-ubyte.gz

然后使用如下语句去读取数据集

img, target = minst_set[0]

其中每个img类型为PILimage,target类型为int,代表该图片对应的数字。

但是在喂给SVM训练时需要的是[batch_size, data]大小的numpy数组,因此需要做一些预处理

   x_, y_ = list(zip(*([(np.array(img).reshape(28*28), target) for img, target in mnist_set])))

上面的语句实现了将MNIST数据集转换成numpy数组的形式,其中x_是每个成员为[1, 784]的numpy数组,y_为对应的数字所组成的列表。

2. SVM训练

支持向量机(support vector machine,SVM)是经典的机器学习算法,其通过选取两个n维支持向量(support vector)之间的n维超平面来对两类对象进行二分类。而专注于分类的SVM又称作Support Vector Classification,SVC。

求解SVM是一个很复杂的问题,但是万幸的是sklearn中有封装的很好的模块,可以很简单的直接使用

from sklearn.svm import SVCsvc = SVC(kernel='rbf', C=1)svc.fit(x_, y_)

其中fit接口接受两个参数,第一个参数为训练数据[batch_size, data],第二个参数为训练标签[batch_size,1]。
SVC的构造函数如下

SVC(C=1.0, kernel='rbf', degree=3, gamma='scale', coef0=0.0, shrinking=True, probability=False, tol=0.001, cache_size=200, class_weight=None, verbose=False, max_iter=-1, decision_function_shape='ovr', random_state=None)

比较重要的参数有kernel,C,decison_function_shape等。

  • kernel参数指定了核函数,常用的有linear,poly,rbf,sigmoid等。
  • C为惩罚系数,C越大,对误分类的惩罚越大,模型越保守,C越小,对误分类的惩罚越小,模型越宽松,也就是较大的C在训练集上会有更高的正确率,较小的C会容许噪声的存在,泛化能力较强。
  • decision_function_shape参数指定了决策函数的形状,ovr表示one-vs-rest,ovo表示one-vs-one,具体的意思可以网络查阅

4. 数字分割

数字分割是指将图像中的数字部分分割出来,然后一个一个喂给SVM进行分类

这里就是使用opencv对拍摄的图像进行轮廓提取后拟合外接矩形,借此来提取数字部分的ROI。

这里选择进行Canny边缘检测后去进行轮廓提取,然后拟合外接矩形,因为相较于直接二值化后去提取数字部分的ROI,
边缘检测对数字与纸张的边界更加敏感,即便在光照不均匀的情况下,也能较好的提取出数字的边缘。鲁棒性强。

5. 杂项与代码

这里还有一些杂项,比如保存模型,加载模型

使用pickle模块对训练好的模型对象进行序列化保存与加载,可以将训练好的模型保存到本地,以便后续使用。

最后贴出代码

代码
import os.path
import cv2
import numpy as np
from matplotlib import pyplot as plt
from torchvision import datasets
from torchvision import transforms
from sklearn import svm
from sklearn import preprocessing
from sklearnex import patch_sklearn
import pickle
from torch.utils.tensorboard import SummaryWriter
from sklearn.model_selection import learning_curve'''@brief  加载MNIST数据集并转换格式成二值图@param train: 是否为训练集@param data_enhance: 是否进行数据增强@return 二值图集和标签集
'''
def LoadMnistDataset(train=True, data_enhance=False):mnist_set = datasets.MNIST(root="./MNIST", train=train, download=True)x_, y_ = list(zip(*([(np.array(img), target) for img, target in mnist_set])))sets_raw = []sets_r20 = []sets_invr20 = []y = []y_r20 = []y_invr20 = []sets = []matrix_r20 = cv2.getRotationMatrix2D((14, 14), 25, 1.0)matrix_invr20 = cv2.getRotationMatrix2D((14, 14), -25, 1.0)select = 0for idx in range(len(x_)):# 对图像进行二值化以及数据增强_, img = cv2.threshold(x_[idx], 255, 255, cv2.THRESH_OTSU)sets_raw.append(np.array(img.data).reshape(784))y.append(y_[idx])if data_enhance:if select % 2 == 0:img_r20 = ~cv2.warpAffine(~img, matrix_r20, (28, 28), borderValue=(255, 255, 255))sets_r20.append(np.array(img_r20.data).reshape(784))y_r20.append(y_[idx])else:img_invr20 = ~cv2.warpAffine(~img, matrix_invr20, (28, 28), borderValue=(255, 255, 255))sets_invr20.append(np.array(img_invr20.data).reshape(784))y_invr20.append(y_[idx])select += 1# 数据增强sets = sets_raw + sets_r20 + sets_invr20sets = np.array(sets)print(sets.shape)if data_enhance:y = y + y_r20 + y_invr20return sets, y'''@brief  保存SVM模型@param svc_model: SVM模型 @param file_path: 模型保存路径,默认为./SVC@return None
'''
def SaveSvcModel(svc_model, file_path="./SVC"):with open(file_path, 'wb') as fs:pickle.dump(svc_model, fs)'''@brief  加载SVM模型@param file_path: 模型保存路径,默认为./SVC@return SVM模型
'''
def LoadSvcModel(file_path="./SVC"):if not os.path.exists(file_path):assert "Model Do Not Exist"with open(file_path, 'rb') as fs:svc_model = pickle.load(fs)return svc_model'''@brief  训练SVM模型@param c: SVM参数C@param enhance: 是否进行数据增强@return acc: 在测试集上的准确率svc_model: SVM模型
'''
def TrainSvc(c, enhance):# 读取数据集,训练集及测试集images_train, targets_train = LoadMnistDataset(train=True, data_enhance=enhance)images_test, targets_test = LoadMnistDataset(train=False, data_enhance=enhance)# 训练svc_model = svm.SVC(C=c,kernel='rbf', decision_function_shape='ovr')svc_model.fit(images_train, targets_train)# 在测试集上测试准确度res = svc_model.predict(images_test)correct = (res == targets_test).sum()accuracy = correct / len(images_test)print(f"测试集上的准确率为{accuracy * 100}%")return svc_model'''@brief  预处理比较粗的字体@param image: 输入图像@:param show: 是否显示预处理后的图像@:param thresh: 二值化阈值@return 预处理后的图像数据
'''
def PreProcessFatFont(image, show=False):# 白底黑字转黑底白字pre_ = ~image# 转单通道灰度pre_ = cv2.cvtColor(pre_, cv2.COLOR_BGR2GRAY)# 二值化_, pre_ = cv2.threshold(pre_, thresh=0, maxval=255, type=cv2.THRESH_OTSU)# resize后添加黑色边框,亲测可提高识别率pre_ = cv2.resize(pre_, (112, 112))_, pre_ = cv2.threshold(pre_, thresh=0, maxval=255, type=cv2.THRESH_OTSU)back = np.zeros((300, 300), np.uint8)back[29:141, 29:141] = pre_pre_ = backif show:cv2.imshow("show", pre_)cv2.waitKey(0)# 做一次开运算(腐蚀 + 膨胀)kernel = np.ones((2, 2), np.uint8)pre_ = cv2.erode(pre_, kernel, iterations=1)kernel = np.ones((3, 3), np.uint8)pre_ = cv2.dilate(pre_, kernel, iterations=1)# 第二次resizepre_ = cv2.resize(pre_, (56, 56))_, pre_ = cv2.threshold(pre_, thresh=0, maxval=255, type=cv2.THRESH_OTSU)# 做一次开运算(腐蚀 + 膨胀)kernel = np.ones((2, 2), np.uint8)pre_ = cv2.erode(pre_, kernel, iterations=1)kernel = np.ones((3, 3), np.uint8)pre_ = cv2.dilate(pre_, kernel, iterations=1)# resize成输入规格pre_ = cv2.resize(pre_, (28, 28))_, pre_ = cv2.threshold(pre_, thresh=0, maxval=255, type=cv2.THRESH_OTSU)# 转换为SVM的输入格式pre_ = np.array(pre_).flatten().reshape(1, -1)return pre_'''@brief  预处理细的字体@param image: 输入图像@param show: 是否显示预处理后的图像@param thresh: 二值化阈值@return 预处理后的图像数据
'''
def PreProcessThinFont(image, show=False):# 白底黑字转黑底白字pre_ = ~image# 转灰度图pre_ = cv2.cvtColor(pre_, cv2.COLOR_BGR2GRAY)# 增加黑色边框pre_ = cv2.resize(pre_, (112, 112))_, pre_ = cv2.threshold(pre_,thresh=0, maxval=255, type=cv2.THRESH_OTSU)back = np.zeros((170, 170), dtype=np.uint8) # 这里不指明类型会导致后续矩阵强转为float64,无法使用大津法阈值back[29:141, 29:141] = pre_pre_ = backif show:cv2.imshow("show", pre_)cv2.waitKey(0)# 对细字体先膨胀一下kernel = np.ones((3, 3), np.uint8)pre_ = cv2.dilate(pre_, kernel, iterations=2)# 第二次resizepre_ = cv2.resize(pre_, (56, 56))_, pre_ = cv2.threshold(pre_, thresh=0, maxval=255, type=cv2.THRESH_OTSU)# 做一次开运算(腐蚀 + 膨胀)kernel = np.ones((2, 2), np.uint8)pre_ = cv2.erode(pre_, kernel, iterations=1)kernel = np.ones((3, 3), np.uint8)pre_ = cv2.dilate(pre_, kernel, iterations=1)# resize成输入规格pre_ = cv2.resize(pre_, (28, 28))_, pre_ = cv2.threshold(pre_, thresh=0, maxval=255, type=cv2.THRESH_OTSU)# 转换为SVM输入格式pre_ = np.array(pre_).flatten().reshape(1, -1)return pre_'''@brief  在空白背景上显示提取出的roi@param res_list: roi列表@return None
'''
def ShowRoi(res_list):back = 255 * np.ones((1000, 1500, 3), dtype=np.uint8)# 图片x轴偏移量tlx = 0for roi in res_list:if tlx + roi.shape[1] > back.shape[1]:break# 每次在原图上加上一个roiback[0:roi.shape[0], tlx:tlx + roi.shape[1], :] = roitlx += roi.shape[1]cv2.imshow("show", back)cv2.waitKey(0)'''@brief  寻找数字轮廓并提取roi@param src: 输入图像@param thin: 是否为细字体@param thresh: 二值化阈值@return roi列表
'''
def FindNumbers(src, thin=True):# 拷贝dst = src.copy()paint = src.copy()roi = src.copy()dst = ~dst# 预处理paint = cv2.resize(paint, (448, 448))dst = cv2.resize(dst, (448, 448))# 记录缩放比例,后来看这一步好像没啥意义fx = src.shape[1] / 448fy = src.shape[0] / 448# 转单通道dst = cv2.cvtColor(dst, cv2.COLOR_BGR2GRAY)# 边缘检测后二值化,直接二值化的话由于采光不同的原因灰度直方图峰与峰之间可能会差距过大,导致二值图的分割不准确# 而边缘检测对像素突变更加敏感,因此采用Canny边缘检测后二值化cv2.Canny(dst, 200, 200, dst)# 对于平常笔写的字太细,膨胀一下if thin:kernel = np.ones((5, 5), np.uint8)dst = cv2.dilate(dst, kernel, iterations=1)# 寻找外围轮廓contours, _ = cv2.findContours(dst, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)# 提取roiroi_list = []rect_list = []for contour in contours:rect = cv2.boundingRect(contour)if not ((rect[2] * rect[3] < 400 or rect[2] * rect[3] > 448 * 448 / 2.5) or (rect[3] < rect[2])):cv2.rectangle(paint, rect, (255, 0, 0), 1)x_min = rect[0] * fxx_max = (rect[0] + rect[2]) * fxy_min = rect[1] * fyy_max = (rect[1] + rect[3]) * fyroi_list.append(roi[int(y_min):int(y_max), int(x_min):int(x_max)].copy())rect_list.append(rect)return paint, roi_list, rect_list'''@brief  以txt形式显示数据@param data: 数据集@return None   
'''
def ShowDataTxt(data):print("----------------------------------------------------------")for i in range(28):for j in range(28):print(0 if data[0][i * 28 + j] == 255 else 1, end='')print('\n')print("----------------------------------------------------------")if __name__ == "__main__":# 加载patch_sklearn()model_path = "./SVC_C1_enhance.pkl"if os.path.exists(model_path):print("Model Exist, Load Form Serialization")model = LoadSvcModel(model_path)else:print("Model Do Not Exist, Train")# 训练model = TrainSvc(1, False)# 保存SaveSvcModel(model, model_path)# 测试paint, nums, rects = FindNumbers(cv2.imread("test_final.jpg"))predict_nums = []for img in nums:data = PreProcessThinFont(img, show=False)# ShowDataTxt(data)predict_nums.append(model.predict(data)[0])for i in range(len(predict_nums)):cv2.putText(paint,str(predict_nums[i]), (rects[i][0], rects[i][1]), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 1)cv2.imshow("show", paint)cv2.waitKey(0)

给出几个识别后的效果:
image

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.ryyt.cn/news/44731.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈,一经查实,立即删除!

相关文章

3.26随笔

SELECT DISTINCT 实例 下面的 SQL 语句仅从 "Websites" 表的 "country" 列中选取唯一不同的值,也就是去掉 "country" 列重复值:实例SELECT DISTINCT country FROM Websites;输出结果:

KVM虚拟化

KVM虚拟化 ============================================================= 0.环境介绍 宿主机:内存4G+ 纯净的系统CentOS-7 1:什么是虚拟化? 虚拟化,通过模拟计算机的硬件,来实现在同一台计算机上同时运行多个不同的操作系统的技术。2:为什么要用虚拟化? 2.1:虚拟化…

利用大模型服务一线小哥的探索与实践

一、小哥作业+大模型 2022年OpenAI基于GPT推出了聊天机器人ChatGPT,带来了非常惊艳的语言理解、内容生成、知识推理等能力,能够准确理解人的语言、意图,并能够回答出清晰、完整的内容,让人很难分辨出沟通交流的是人类还是机器人。 大模型会尝试基于已有的内容,生成内容的延…

腾讯云+Ollama部署远程访问大模型api

Ollama是个极为方便的大模型框架 1.腾讯云上选购合适的云服务器,为了方便拉取模型,地区建议选择北美(计费模式选择按量计费是为了省钱,老板有钱的话随意)架构选择异构计算镜像选择Ubuntu22.04,驱动版本默认就行,云硬盘默认50G即可网络默认分配即可,一定要选择分配独立公网IP,否…

3.21随笔

SELECT Column 实例 下面的 SQL 语句从 "Websites" 表中选取 "name" 和 "country" 列:实例SELECT name,country FROM Websites;输出结果为:

中电金信:银行业数据中心何去何从

​ 20多年前,计算机走进国内大众视野,计算机行业迎来在国内的高速发展时代。银行业是最早使用计算机的行业之一,也是计算机技术应用最广泛、最深入的行业之一。近年来,随着银行竞争加剧,科技如何引领业务、金融科技如何发展,直接关系到银行的生存空间和发展命脉。银行业I…

团队作业sprint第九天

2024-05-04 项目任务进展: 5小时(46/50) 会议照片 过去一天完成了哪些任务完成AI对话的测试接下来的计划优化各个页面 继续学习flutter和Springboot还剩下哪些任务优化主页面 专栏功能的管理的优化 测试整个软件遇到了哪些困难出现了一些奇怪的dug 边学习边进行功能开发问题多多…

团队作业sprint第十天

2022-05-06项目任务进展: 4小时(50/50) 会议照片 过去一天完成了哪些任务完成专栏的测试 优化主页面 专栏功能的管理的优化 测试整个软件遇到了哪些困难出现了一些奇怪的dug 边学习边进行功能开发问题多多 Springboot的学习很困难,经常遇到很多问题一直在网上查找解决相关问题 …