深度学习实战5-卷积神经网络(CNN)中文OCR识别项目

2023-11-02

文章目录

一、前期工作

  1. 导入库
  2. 图片生成函数
  3. 导入数据
  4. 生成数据集函数

二、CNN模型建立
三、训练模型函数
四、训练模型与结果
五、验证

大家好,我是微学AI,今天给大家带来一个利用卷积神经网络(CNN)进行中文OCR识别,实现自己的一个OCR识别工具。
一个OCR识别系统,其目的很简单,只是要把影像作一个转换,使影像内的图形继续保存、有表格则表格内资料及影像内的文字,一律变成计算机文字,使能达到影像资料的储存量减少、识别出的文字可再使用及分析,这样可节省人力打字的时间。
中文OCR识别的注意流程图:
在这里插入图片描述

一、前期工作

1.导入库

import numpy as np 
from PIL import Image, ImageDraw, ImageFont
import cv2
import tensorflow as tf
from keras.models import Sequential
from keras.layers import Dense, Dropout, Activation, Flatten
from keras.layers import Convolution2D, MaxPooling2D
from tensorflow.keras.optimizers import SGD
from tensorflow.keras.utils import plot_model
import matplotlib.pyplot as plt

#导入字体
DroidSansFallbackFull = ImageFont.truetype("DroidSansFallback.ttf", 36, 0)
fonts = [DroidSansFallbackFull,]

2.数据集生成函数

#生成图片,48*48大小
def affineTrans(image, mode, size=(48, 48)):
    # print("AffineTrans ...")
    if mode == 0:  # padding移动
        which = np.array([0, 0, 0, 0])
        which[np.random.randint(0, 4)] = np.random.randint(0, 10)
        which[np.random.randint(0, 4)] = np.random.randint(0, 10)
        image = cv2.copyMakeBorder(image, which[0], which[0], which[0], which[0], cv2.BORDER_CONSTANT, value=0)
        image = cv2.resize(image, size)
    if mode == 1:
        scale = np.random.randint(48, int(48 * 1.4))
        center = [scale / 2, scale / 2]
        image = cv2.resize(image, (scale, scale))
        image = image[int(center[0] - 24):int(center[0] + 24), int(center[1] - 24):int(center[1] + 24)]

    return image

#图片处理 除噪
def noise(image, mode=1):
    # print("Noise ...")
    noise_image = (image.astype(float) / 255) + (np.random.random((48, 48)) * (np.random.random() * 0.3))
    norm = (noise_image - noise_image.min()) / (noise_image.max() - noise_image.min())
    if mode == 1:
        norm = (norm * 255).astype(np.uint8)
    return norm
    
#绘制中文的图片
def DrawChinese(txt, font):
    # print("DrawChinese...")
    image = np.zeros(shape=(48, 48), dtype=np.uint8)
    x = Image.fromarray(image)
    draw = ImageDraw.Draw(x)
    draw.text((8, 2), txt, (255), font=font)
    result = np.array(x)
    return result

#图片标准化
def norm(image):
    # print("norm...")
    return image.astype(np.float) / 255

3.导入数据

char_set = open("chinese.txt",encoding = "utf-8").readlines()[0]
print(len(char_set[0]))  # 打印字的个数

4.生成数据集函数

# 生成数据:训练集和标签
def Genernate(batchSize, charset):
    # print("Genernate...")
    #    pass
    label = [];
    training_data = [];

    for x in range(batchSize):
        char_id = np.random.randint(0, len(charset))
        font_id = np.random.randint(0, len(fonts))
        y = np.zeros(dtype=np.float, shape=(len(charset)))
        image = DrawChinese(charset[char_id], fonts[font_id])
        image = affineTrans(image, np.random.randint(0, 2))
        # image = noise(image)
        # image = augmentation(image,np.random.randint(0,8))
        image = noise(image)
        image_norm = norm(image)
        image_norm = np.expand_dims(image_norm, 2)

        training_data.append(image_norm)
        y[char_id] = 1
        label.append(y)

    return np.array(training_data), np.array(label)

def Genernator(charset,batchSize):
    print("Generator.....")

    while(1):
        label = [];
        training_data = [];
        for i in range(batchSize):
            char_id = np.random.randint(0, len(charset))
            font_id = np.random.randint(0,len(fonts))
            y = np.zeros(dtype=np.float,shape=(len(charset)))
            image = DrawChinese(charset[char_id],fonts[font_id])
            image = affineTrans(image,np.random.randint(0,2))
            #image = noise(image)
            #image = augmentation(image,np.random.randint(0,8))
            image = noise(image)
            image_norm = norm(image)
            image_norm  = np.expand_dims(image_norm,2)
            y[char_id] = 1
            training_data.append(image_norm)
            label.append(y)

        y[char_id] = 1
        yield (np.array(training_data),np.array(label))

二、CNN模型建立

def Getmodel(nb_classes):

    img_rows, img_cols = 48, 48
    nb_filters = 32
    nb_pool = 2
    nb_conv = 4

    model = Sequential()
    print("sequential..")
    model.add(Convolution2D(nb_filters, nb_conv, nb_conv,
                            padding='same',
                            input_shape=(img_rows, img_cols, 1)))
    print("add convolution2D...")
    model.add(Activation('relu'))
    print("activation ...")
    model.add(MaxPooling2D(pool_size=(nb_pool, nb_pool)))
    model.add(Dropout(0.25))
    model.add(Convolution2D(nb_filters, nb_conv, nb_conv,padding='same'))
    model.add(Activation('relu'))
    model.add(MaxPooling2D(pool_size=(nb_pool, nb_pool)))
    model.add(Dropout(0.25))
    model.add(Flatten())
    model.add(Dense(1024))
    model.add(Activation('relu'))
    model.add(Dropout(0.5))
    model.add(Dense(nb_classes))
    model.add(Activation('softmax'))
    model.compile(loss='categorical_crossentropy',
                  optimizer='adam',
                  metrics=['accuracy'])
    return model

#评估模型
def eval(model, X, Y):
    print("Eval ...")
    res = model.predict(X)

三、训练模型函数

#训练函数
def Training(charset):
    model = Getmodel(len(charset))
    while (1):
        X, Y = Genernate(64, charset)
        model.train_on_batch(X, Y)
        print(model.loss)

#训练生成模型
def TrainingGenerator(charset, test=1):

    set = Genernate(64, char_set)
    model = Getmodel(len(charset))
    BatchSize = 64
    model.fit_generator(generator=Genernator(charset, BatchSize), steps_per_epoch=BatchSize * 10, epochs=15,
                        validation_data=set)

    model.save("ocr.h5")

    X = set[0]
    Y = set[1]
    if test == 1:
        print("============6 Test == 1 ")
        for i, one in enumerate(X):
            x = one
            res = model.predict(np.array([x]))
            classes_x = np.argmax(res, axis=1)[0]
            print(classes_x)

            print(u"Predict result:", char_set[classes_x], u"Real result:", char_set[Y[i].argmax()])
            image = (x.squeeze() * 255).astype(np.uint8)
            cv2.imwrite("{0:05d}.png".format(i), image)

四、训练模型与结果

TrainingGenerator(char_set)  #函数TrainingGenerator 开始训练
Epoch 1/15
640/640 [==============================] - 63s 76ms/step - loss: 8.1078 - accuracy: 3.4180e-04 - val_loss: 8.0596 - val_accuracy: 0.0000e+00
Epoch 2/15
640/640 [==============================] - 102s 159ms/step - loss: 7.5234 - accuracy: 0.0062 - val_loss: 6.2163 - val_accuracy: 0.0781
Epoch 3/15
640/640 [==============================] - 38s 60ms/step - loss: 5.9793 - accuracy: 0.0425 - val_loss: 4.1687 - val_accuracy: 0.3281
Epoch 4/15
640/640 [==============================] - 45s 71ms/step - loss: 5.0450 - accuracy: 0.0889 - val_loss: 3.1590 - val_accuracy: 0.4844
Epoch 5/15
640/640 [==============================] - 37s 58ms/step - loss: 4.5251 - accuracy: 0.1292 - val_loss: 2.5326 - val_accuracy: 0.5938
Epoch 6/15
640/640 [==============================] - 38s 60ms/step - loss: 4.1708 - accuracy: 0.1687 - val_loss: 1.9666 - val_accuracy: 0.7031
Epoch 7/15
640/640 [==============================] - 35s 54ms/step - loss: 3.9068 - accuracy: 0.1951 - val_loss: 1.8039 - val_accuracy: 0.7812
...

910
Predict result: 妻 Real result: 妻
1835
Predict result: 莱 Real result: 莱
3107
Predict result: 阀 Real result: 阀
882
Predict result: 培 Real result: 培
1241
Predict result: 鼓 Real result: 鼓
735
Predict result: 豆 Real result: 豆
1844
Predict result: 巾 Real result: 巾
1714
Predict result: 跌 Real result: 跌
2580
Predict result: 骄 Real result: 骄
1788
Predict result: 氧 Real result: 氧

生成字体图片:
在这里插入图片描述

五、验证

model = tf.keras.models.load_model("ocr.h5")
img1 = cv2.imread('00001.png',0)
img = cv2.resize(img1,(48,48))

print(img.shape)
img2 = tf.expand_dims(img, 0)
res = model.predict(img2)
classes_x = np.argmax(res, axis=1)[0]
print(classes_x)
print(u"Predict result:", char_set[classes_x]) 

中文字:
在这里插入图片描述
预测结果为”莱;
数据集的获取私信我!后期有更深入的OCR识别功能呈现,敬请期待!

往期作品:

深度学习实战项目

1.深度学习实战1-(keras框架)企业数据分析与预测

2.深度学习实战2-(keras框架)企业信用评级与预测

3.深度学习实战3-文本卷积神经网络(TextCNN)新闻文本分类

4.深度学习实战4-卷积神经网络(DenseNet)数学图形识别+题目模式识别

5.深度学习实战5-卷积神经网络(CNN)中文OCR识别项目

6.深度学习实战6-卷积神经网络(Pytorch)+聚类分析实现空气质量与天气预测

7.深度学习实战7-电商产品评论的情感分析

8.深度学习实战8-生活照片转化漫画照片应用

9.深度学习实战9-文本生成图像-本地电脑实现text2img

10.深度学习实战10-数学公式识别-将图片转换为Latex(img2Latex)

11.深度学习实战11(进阶版)-BERT模型的微调应用-文本分类案例

12.深度学习实战12(进阶版)-利用Dewarp实现文本扭曲矫正

13.深度学习实战13(进阶版)-文本纠错功能,经常写错别字的小伙伴的福星

14.深度学习实战14(进阶版)-手写文字OCR识别,手写笔记也可以识别了

15.深度学习实战15(进阶版)-让机器进行阅读理解+你可以变成出题者提问

16.深度学习实战16(进阶版)-虚拟截图识别文字-可以做纸质合同和表格识别

17.深度学习实战17(进阶版)-智能辅助编辑平台系统的搭建与开发案例

18.深度学习实战18(进阶版)-NLP的15项任务大融合系统,可实现市面上你能想到的NLP任务

19.深度学习实战19(进阶版)-ChatGPT的本地实现部署测试,自己的平台就可以实现ChatGPT

…(待更新)

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

深度学习实战5-卷积神经网络(CNN)中文OCR识别项目 的相关文章

随机推荐

  • [Android开发] 点击小图查看大图的功能、图片浏览器,聊天页面、朋友圈、微博、广场页面的查看大图功能,浏览大图功能,拖动关闭大图页面,完美的甚至完胜微信的过渡动画

    一 简介 属于你的侵入性低的大图查看器 完美的甚至完胜微信的过渡动画 同样支持视频 另外您可以自定义加载图片的内核 例如Glide Picasso或其他的 github地址 https github com FlyJingFish Open
  • Microsoft Minecraft产品线梳理

    目录 前言 1 Minecraft游戏本体 2 Minecraft Education Edition 3 Minecraft Marketplace 4 Minecraft Realms 5 Minecraft Dungeons
  • linux下线程使用

    什么是线程 线程 是进程内部的一个控制序列 即使不使用线程 进程内部也有一个执行线程 为什么要使用多线程 使用fork创建进程以执行新的任务 该方式的代价很高 多个进程间不会直接共享内存 线程是进程的基本执行单元 一个进程的所有任务都在线程
  • conda配置清华源

    conda配置清华源 问题描述 在Anaconda下准备下载东西的时候 出现了这样的报错或者出现下载速度极其慢 原因 因为默认下载东西的链接是国外的镜像 存在一些问题 因此将源改成国内的镜像即可 解决方案 1 使用清华的镜像链接 用清华镜像
  • Java 随机产生四位验证码

    import java util Random public class RandGenDemo public static void main String args 静态的方法不用new 直接对象 方法名 System out prin
  • linux 视频教程 韦山东,韦东山 linux 设备树详解

    简 介 设备树视频录制完毕 29节 现在只要69元 学员对此课程的评价 这是最翔实最实惠最精益求精的设备树教程 感兴趣的了解一下 以下是课程详情 设备树是什么 设备树是一种机制 用文本的方式描述硬件资源 我们写驱动前要先看原理图 确定硬件连
  • Listener

    观察者设计模式 它是事件驱动的一种体现形式 就好比在做什么事情的时候被人盯着 当对应做到某件事时 触发事件 观察者模式通常由以下三部分组成 1 事件源 触发事件的对象 2 事件 触发的动作 里面封装了事件源 3 监听器 当事件源触发事件时
  • postman本地测试接口的地址路径,如何获取和拼接

    首先 在本地进行接口自测 那么就是本地的ip 既 http localhost 其他就是其他的ip地址 接着是端口号以及所添加的共用路径 我们可以从springboot项目的配置文件application yml获取 找到 server 服
  • vue3中Cron表达式的使用

    效果
  • Python中多进程间通信(multiprocessing.Manager)

    Python中写多进程的程序 一般都使用multiprocesing模块 进程间通讯有多种方式 包括信号 管道 消息队列 信号量 共享内存 socket等 这里主要介绍使用multiprocessing Manager模块实现进程间共享数据
  • 爬虫的大概思路

    爬虫一般来说两种 一种是页面分析 分析页面获取整理出数据 毕竟是要展示数据在页面 获取这些从页面上 另一种是获取对应接口 通过API方式来获取 因为归根到底 都是前台后端交互发送请求响应请求 两种方式各有优劣 方式一应该是比较常见的 但是从
  • 断言简介说明

    转自 断言简介说明 下文笔者讲述断言简介说明 如下所示 断言简介 在Java中 assert关键字是从Java 4开始引入的 为了避免和老版本的Java代码中使用了assert关键字导致错误 Java在执行的时候默认是不启动断言检查的 这个
  • 后端代码审计——PHP数组

    文章目录 PHP数组 1 索引数组 2 关联数组 3 数组创建 3 1 直接赋值 3 2 array 语言结构 4 多维数组 4 1 创建多维数组 5 数组元素访问 5 1 组元素操作 5 2 元素操作 5 3 数组的遍历 5 4 for
  • win11 安装 Anaconda(2022.10)+pycharm(2022.3/2023.1.4)+配置虚拟环境

    目录 一 安装Anaconda 二 Anaconda配置环境变量 三 Anaconda更改虚拟环境安装路径 创建虚拟环境 四 安装pycharm 五 pycharm配置Anaconda环境 一 安装Anaconda 1 下载 官网慢 可以选
  • app渗透-外在信息收集

    app渗透 外在信息收集 5 外在信息收集 5 1外在抓包 frida r0capture 5 1 1 frida的安装和使用 1 安装 2 使用测试 5 2 1 r0capture使用 5 外在信息收集 5 1外在抓包 frida r0c
  • 问题 D: 数据结构练习 -- 栈的操作

    题目描述 对输入整数序列1 2 3 执行一组栈操作 输出操作的出栈序列 输入 每行是一个测试用例 表示一个操作序列 操作序列由P和Q两个符号组成 P表示入栈 Q表示出栈 每个操作序列长度不超过1000 输出 对每个操作序列 输出出栈序列 若
  • MySQL——数据的增删改

    2023 9 12 本章开始学习DML 数据操纵语言 语言 相关学习笔记如下 DML语言 数据操作语言 插入 insert 修改 update 删除 delete 一 插入语句 方式一 经典的插入 语法 insert into 表名 列名
  • 详解JavaNIO Buffer类的属性和方法

    前言 我们知道 Java中的NIO实际上使用的是多种IO模型中的IO多路复用策略 在NIO中 引入了Buffer缓冲区 Channel通道 Selector选择器三个概念 现在先看一下Buffer缓冲区的一些基本知识 介绍 NIO的Buff
  • mac os如何使用rz、sz

    1 什么是rz sz 在线上真实生产环境中总会有上传文件到服务器 以及从服务器下载文件的需求 rz sz应用广泛 由于发送和接收都是在服务器上进行的 所以 rz received 接收 意味着向服务器上传 sz send 发送 意味着从服务
  • 深度学习实战5-卷积神经网络(CNN)中文OCR识别项目

    文章目录 一 前期工作 导入库 图片生成函数 导入数据 生成数据集函数 二 CNN模型建立 三 训练模型函数 四 训练模型与结果 五 验证 大家好 我是微学AI 今天给大家带来一个利用卷积神经网络 CNN 进行中文OCR识别 实现自己的一个