pytorch 人脸识别

2023-11-16

import torch
import os
import numpy as np
import torch.nn as nn
import  matplotlib.pyplot as plt
import time
import torchvision
from torchvision import  transforms,models,datasets
import torch.optim as optim

#训练集在train文件夹下,每种类别的人脸都位于同一个子目录下。验证集数据类似
data_dir="F:/muct人脸数据库_项目"
train_dir=data_dir+"/train"
valid_dir=data_dir+"/valid"


data_transform=transforms.Compose(
        [
            transforms.Resize(224),
            transforms.ToTensor(),
            transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])
        ]
    )

#加载数据
image_datasets={
    x:torchvision.datasets.ImageFolder(os.path.join(data_dir,x),
    transform=data_transform)for x in ["train","valid"]
}

dataLoaders={x:torch.utils.data.DataLoader(image_datasets[x],
             batch_size=4,shuffle=True)for x in ["train","valid"]
}

#使用GPU
device=torch.device("cuda")

#加载模型
model=torchvision.models.resnet152(True)
#冻住模型中的参数
for param in model.parameters():
    param.requires_grad=False

#修改最后的全连接层,使其适应咱们的项目-276分类
#in_features是全连接层中的输入的维数
num_fts=model.fc.in_features
model.fc=nn.Linear(num_fts,276)


#将模型加载到GPU
model=model.to(device)

#设置优化器
optimizer=optim.Adam(model.fc.parameters(),lr=1e-2)

#损失函数
criterion=nn.CrossEntropyLoss()

for epoch in range(5):
    print("Epoch:", epoch)
    print("---" * 5)
    for phase in ["train","valid"]:
        rightnumber = 0
        rightacc = 0
        if phase =="train":
            model.train()
        else:
            model.eval()

         # 把数据都取个遍
        for inputs, labels in dataLoaders[phase]:
            inputs = inputs.to(device)
            labels = labels.to(device)

            # 清零
            optimizer.zero_grad()
            # 只有训练的时候计算和更新梯度
            with torch.set_grad_enabled(phase == 'train' or phase == 'valid'):
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                _, preds = torch.max(outputs, 1)
                if phase == 'train':
                    loss.backward()
                    optimizer.step()

                # 计算损失
                rightnumber+= torch.sum(preds == labels.data)

            rightacc = rightnumber.double() / len(dataLoaders[phase].dataset)
        print(rightacc.item())
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

pytorch 人脸识别 的相关文章

随机推荐

  • 20-文件下载及读取漏洞

    WEB 漏洞 文件操作之文件下载读取全解 思维导图 1 文件被解析 则是文件解析漏洞 2 显示源代码 则是文件读取漏洞 3 提示文件下载 则是文件下载漏洞 文件下载漏洞 利用条件 1 存在读文件的函数和操作 2 读取文件的路径用户可控且未校
  • Android 保存资源图片到相册最新写法适用于Android10.0及以上

    博主前些天发现了一个巨牛的人工智能学习网站 通俗易懂 风趣幽默 忍不住也分享一下给大家 点击跳转到网站 一 首先在AndroidManifest xml中加入权限
  • APFS 文件系统探究

    本文的创作初衷是因为我发现从底层详解 APFS 的资料很少 所以自己来进行了一些探究和整理 一点说明 如果你在看 APFS 的文档或者其他内容 不要把高层级的分区理解成 Windows 中的分区 因为 APFS 里卷 Volume 才是显示
  • OUT指令时,就进入了I/O端口读写周期

    1 译码电路的输入信号 每当CPU执行IN或者OUT指令时 就进入了I O端口读写周期 此时首先是端口地址有效 然后是I O读写控制信号 IOR和 IOW有效 把对端口地址译码而产生的译码信号同 IOR和 IOW结合起来一同控制对I O端口
  • 聊聊FFT

    关于FFT 全称为快速傅里叶变换 目的是把时域的信号转变为频域的信号 具体的科学解释及计算方程组可以去查百度百科 不过小编不建议这么做 因为查了也看不懂的 先看一张都能看懂的图 这是某种食物的配方表 每种配方包含了多少比例标注的很清楚 对于
  • 计算机网络教程_第二章物理层_整理与复习

    计算机网络教程 第一章 概述 第二章 物理层 第三章 数据链路层 提示 写完文章后 目录可以自动生成 如何生成可参考右边的帮助文档 文章目录 计算机网络教程 1 物理层的作用及主要任务 2 数据传输的方式 并行 串行 异步 同步 P40 3
  • python 设置下载源,全局设置

    推荐使用豆瓣的 个人感觉最好用 当然 你如果喜欢其它的 也可以设置 pip config set golbal index url https pypi douban com simple 设置成功 windows 提示的配置文件在 ini
  • Spyder上使用tensorflow训练完成时出现SystemExit异常

    使用spyder tensorflow实现迁移学习训练inception v3网络 训练完成后提示 SystemExit home zhijuan anaconda3 lib site packages python3 6 site pac
  • 深度学习 图像分割综述

    文章目录 前言 语义分割 实例分割 技术路线 掩膜建议分类法 先检测再分割法 标记像素后聚类法 密集滑动窗口法 参考 前言 图像分割在计算机视觉中是个重要的任务 在地理信息系统 医学影像 自动驾驶 机器人等领域都有着很重要的应用技术支持作用
  • TensorFlow框架做实时人脸识别小项目(二)

    在第一部分中 分析了整个小项目的体系 重点讨论了用于人脸检测对齐的mtcnn网络的实现原理 并利用笔记本电脑自带的摄像头进行了测试 今天在这里要讨论的重点是人脸识别中的核心部分 facenet网络 facenet是Google开源的人脸识别
  • 从CPU cache一致性的角度看Linux spinlock的不可伸缩性(non-scalable)

    凌晨一点半的深圳雨夜 豪雨当夜惊起有人赏 笑叹落花无声空飘零 喜欢这种豪雨 让人兴奋 惊起作文以呜呼之感叹 引用上一篇文章 优化多核CPU的TCP新建连接性能 重排spinlock https blog csdn net dog250 ar
  • 图片<img>、链接<a>等去除referer标记

    1 img 标签 img src src
  • 2011年北京大学计算机研究生机试真题(题解)

    九度OJ题目传送门 2011年北京大学计算机研究生机试真题 鸡兔同笼 题目描述 一个笼子里面关了鸡和兔子 鸡有2只脚 兔子有4只脚 没有例外 已经知道了笼子里面脚的总数a 问笼子里面至少有多少只动物 至多有多少只动物 输入 第1行是测试数据
  • 存储路径_存储多路径

    今天的话题是存储多路径 三国开篇 天下大势 合久必分 分久必合 我觉得用来形容多路径也非常贴切 它可以将多条路径整合成一条 也可以在单条路径出现问题时迅速切换 先简单介绍下多路径 IT存储系统在构建的时候 为了最大化保证安全 通常会采用冗余
  • C++实现——LCS-最大公共子串长度

    求两个字符串的最长公共子串的长度 子串不一定是原串中的连续子串组成 LCS 使用动态规划 include
  • Python基础知识及概念

    Python基础知识及概念 1 注释 单行注释 这是一个单行注释 在程序开发时 同样可以使用 在代码的后面 旁边 增加说明性的文字 但是 需要注意的是 为了保证代码的可读性 注释和代码之间 至少要有 两个空格 示例代码如下 print he
  • Vue-Quill-Editor 设置编辑器中文字的默认字体大小

    Vue Quill Editor 默认字体看起来有些小 如下 设置默认字体大小 ql container 设置默认字号 font size 16px 设置之后
  • 利用jsqlparser解析SQL语句

    时常会遇到很多情况 我们需要对SQL语句进行替换或者拼接 以往我们可能会用StringBuild来进行拼接 StringBuilder sql new StringBuilder sql append select from sql app
  • 开发框架Furion之Winform+SqlSugar

    目录 1 开发环境 2 项目搭建 2 1 创建WinFrom主项目 2 2 创建子项目 2 3 实体类库基础类信息配置 2 3 1 Nuget包及项目引用 2 3 2 实体基类创建 2 4 仓储业务类库基础配置 2 4 1 Nuget包及项
  • pytorch 人脸识别

    import torch import os import numpy as np import torch nn as nn import matplotlib pyplot as plt import time import torch