说明
使用torchvision.model加载预训练好的模型时,发现默认下载路径在系统盘下面的用户目录下(这个你执行的时候就会发现),即C:\用户名\.cache\torch\.checkpoints
下,便于统一管理,我决定修改model的存放路径,在网上找了很久都没有很好的解决方法,只能自己尝试,现将解决方案给出,供大家参考~
操作环境
- windows10 + Anaconda
- torch:1.1.0
- torchvision:0.3.0
加载方式
以加载vgg16为例,首先定义网络结构
class CNN(nn.Module):
def __init__(self, usegpu=True):
super(CNN, self).__init__()
self.model = models.__dict__['vgg16'](pretrained=True)
self.model = nn.Sequential(*list(self.model.children())[0])
self.model = nn.Sequential(*list(self.model.children())[:16])
def forward(self, x):
x = self.model(x)
return x
在执行model = CNN()
时即会触发下载动作,此时会默认下载到C:\用户名\.cache\torch\.checkpoints
下
修改方法
总的原则就是修改源码
- 经过这篇博客的介绍,发现pytorch的默认下载路径是由
load_state_dict_from_url
函数进行控制,那么就好办了,只需要找到这个函数进行修改即可 - 由于我是下载
vgg16
,所以我先找到vgg.py
源码,位于python路径/torchvision/models/vgg.py
(其中python路径即是你安装python所在的地址,假设你是用Anaconda创建了一个名为envtest
的虚拟环境,那么所有安装的库都会在Anaconda/envs/envtest/Lib/site-packages
这个文件夹下) - 接下来就是套娃操作
- 在
vgg.py
直接搜索load_state_dict_from_url
发现有如下语句
![vgg.py](https://img-blog.csdnimg.cn/20200311113548501.png)
所以需要在本目录下找到utils.py
,看看里面有没有load_state_dict_from_url
函数 - 打开
utils.py
,只有如下代码
![utils.py](https://img-blog.csdnimg.cn/20200311113744680.png)
显然,需要到torch安装路径
下找到hub.py
3.找到hub.py
,搜索load_state_dict_from_url
,成功找到如下代码
![hub.py](https://img-blog.csdnimg.cn/20200311113958361.png?x-oss-process=image/watermark,type_ZmFuZ3poZW5naGVpdGk,shadow_10,text_aHR0cHM6Ly9ibG9nLmNzZG4ubmV0L1Byb0xvdmVyOTg=,size_16,color_FFFFFF,t_70)
根据描述,可发现model_dir
参数即为下载模型的默认路径,所以直接将model_dir = None
换成model_dir = 想要的模型下载绝对路径
即可,感兴趣的同学可以仔细专研,这里就不过多阐述
后记
这里说明一下我上面这么做的原因:
- 为什么需要找到源码写死?直接使用函数修改不行吗?
因为考虑到后续可能还要下载其他预训练模型,与其每次进行修改,还不如一次性写死 - 为什么需要
绝对路径
,相对路径不行吗?
相对路径也可以,但考虑到绝对路径更加直观,所以我这里使用的绝对路径
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)