RuntimeError: expected scalar type Double but found Float
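Python's built-in type() is not very helpful here. A tensor's own methods and attributes (such as .dtype and .type()) are much clearer. PyTorch's default floating-point type is float (float32), while NumPy's default is float64, also called double. The two should not be mixed in one computation, for example by feeding NumPy data straight into a layer with float32 weights. MATLAB also defaults to double, I believe. Beyond dtypes, also watch out for tensor operations on the GPU, the kind of tensor (variable or constant), whether it requires gradients, and how tensors are combined in computations.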
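Here is a minimal sketch of inspecting these properties through the tensor's own attributes. It assumes a standard PyTorch and NumPy install with default settings, meaning no torch.set_default_dtype call:

```python
import numpy as np
import torch

t = torch.zeros(3)        # PyTorch's default floating-point dtype
a = np.zeros(3)           # NumPy's default floating-point dtype

print(type(t))            # <class 'torch.Tensor'>: says nothing about precision
print(t.dtype, t.type())  # torch.float32 torch.FloatTensor
print(a.dtype)            # float64

b = torch.from_numpy(a)   # from_numpy keeps NumPy's dtype
print(b.dtype)            # torch.float64

# Device and gradient tracking are attributes too
print(t.device, t.requires_grad)  # cpu False
```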
(pytorch17) [stu514-17@server5 ANN2SNN_tool_chain]$ python ann2snn.py example_net --weight_bitwidth 16 --timesteps 64 --finetune_epochs 10 --ann_weight checkpoint/example_net.pth --save_file out_snn.pth --num_workers 1 --batch_size 200 --test_batch_size 200 --device '0' --hardware 'cpu'
/data/student/stu514-17/anaconda3/envs/pytorch17/lib/python3.8/site-packages/torch/cuda/__init__.py:52: UserWarning: CUDA initialization: The NVIDIA driver on your system is too old (found version 10010). Please update your GPU driver by downloading and installing a new version from the URL: http://www.nvidia.com/Download/index.aspx Alternatively, go to: https://pytorch.org to install a PyTorch version that has been compiled with your version of the CUDA driver. (Triggered internally at /opt/conda/conda-bld/pytorch_1607369981906/work/c10/cuda/CUDAFunctions.cpp:100.)
return torch._C._cuda_getDeviceCount() > 0
Files already downloaded and verified
ExampleNet0(
Epoch: 0
/data/student/stu514-17/code_uzip/ANN2SNN_tool_chain/quantization.py:106: UserWarning: This overload of add_ is deprecated:
add_(Number alpha, Tensor other)
Consider using one of the following signatures instead:
add_(Tensor other, *, Number alpha) (Triggered internally at /opt/conda/conda-bld/pytorch_1607369981906/work/torch/csrc/utils/python_arg_parser.cpp:882.)
d_p.add_(weight_decay, weight_data)
59 250 Loss: 0.061 | Acc: 98.167% (11780/12000)
119 250 Loss: 0.057 | Acc: 98.263% (23583/24000)
179 250 Loss: 0.057 | Acc: 98.264% (35375/36000)
239 250 Loss: 0.055 | Acc: 98.325% (47196/48000)
/data/student/stu514-17/anaconda3/envs/pytorch17/lib/python3.8/site-packages/torch/optim/lr_scheduler.py:156: UserWarning: The epoch parameter in `scheduler.step()` was not necessary and is being deprecated where possible. Please use `scheduler.step()` to step the scheduler. During the deprecation, if epoch is different from None, the closed form is used instead of the new chainable form, where available. Please open an issue if you are unable to replicate your use case: https://github.com/pytorch/pytorch/issues/new/choose.
warnings.warn(EPOCH_DEPRECATION_WARNING, UserWarning)
add op view1: ['relu5_out1']->['view1_out1']
add op fc1: ['view1_out1']->['fc1_out1']
end_nodes(outputs of the model): ['fc1_out1']
set conv1: Vthr tensor([4.5568]) bias tensor([0.0356, 0.0356, 0.0356], grad_fn=<SliceBackward>) in_scales 1.0
set avg_pool2d1: Vthr tensor([4.1724]) bias tensor([0.0326, 0.0326, 0.0326], grad_fn=<SliceBackward>) in_scales 4.5568132400512695
set conv2: Vthr tensor([4.8173]) bias tensor([0.0376, 0.0376, 0.0376], grad_fn=<SliceBackward>) in_scales 4.172365665435791
set avg_pool2d2: Vthr tensor([4.2452]) bias tensor([0.0332, 0.0332, 0.0332], grad_fn=<SliceBackward>) in_scales 4.817252159118652
set conv3: Vthr tensor([3.4593]) bias tensor([0.0270, 0.0270, 0.0270], grad_fn=<SliceBackward>) in_scales 4.245199203491211
set conv4: Vthr tensor([1.8026]) bias tensor([0.0141, 0.0141, 0.0141], grad_fn=<SliceBackward>) in_scales 3.459290027618408
set conv5: Vthr tensor([1.1337]) bias tensor([0.0089, 0.0089, 0.0089], grad_fn=<SliceBackward>) in_scales 1.8026059865951538
set fc1: Vthr tensor([13.5833]) bias tensor([0.1061, 0.1061, 0.1061], grad_fn=<SliceBackward>) in_scales 1.1336519718170166
Performing validation for SNN
/data/student/stu514-17/code_uzip/ANN2SNN_tool_chain/validation.py:60: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
mask = torch.tensor(mask,dtype=torch.float)
Mean Firing ratios 0.13229199999736416, Firing ratios: [0.09463893 0.10301569 0.08803387 0.09965187 0.062542 0.08998064
0.2368007 0.2368007 0.17916359]
SNN Prec@1 83.890, Prec@5 98.730, Time 287.67187, Loss: 0.663
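The deprecation warnings in this log each name their own fix. For the `add_` warning, here is a sketch of the same weight-decay step written with the newer signature. The values are illustrative and not taken from quantization.py; only the call pattern matters:

```python
import torch

weight_decay = 5e-4          # illustrative values, not from quantization.py
weight_data = torch.ones(3)
d_p = torch.zeros(3)

# Deprecated form from the warning: d_p.add_(weight_decay, weight_data)
# Newer form: tensor first, scalar multiplier passed as the keyword `alpha`
d_p.add_(weight_data, alpha=weight_decay)   # d_p += weight_decay * weight_data
print(d_p)
```

The scheduler warning works the same way: call `scheduler.step()` without passing the epoch.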
Recently, while using the sequitur library to quickly build an autoencoder, I ran into RuntimeError: expected scalar type Double but found Float.
Relevant code:
import torch
from sequitur.models import LINEAR_AE
model = LINEAR_AE(
    input_dim=300,
    encoding_dim=20,
    h_dims=[120, 60],
    h_activ=None,
    out_activ=None
)
print(model)
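The error says the expected tensor type was Double but a Float was found. The message does not always name the two sides accurately, but at root it is a dtype mismatch. Start by printing the type of every parameter in the model: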
for name, param in model.named_parameters():
    print(name, '-->', param.type(), '-->', param.dtype, '-->', param.shape)
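Parameters only show the model's side of the mismatch. To see the dtype that actually arrives at each layer, you can register forward pre-hooks on the leaf modules. The sketch below is minimal and uses a hypothetical `nn.Sequential` stand-in rather than the sequitur model:

```python
import torch
import torch.nn as nn

# Hypothetical stand-in; substitute the model you are debugging
model = nn.Sequential(nn.Linear(300, 120), nn.ReLU(), nn.Linear(120, 20))
seen = []

def log_input(module, inputs):
    # Pre-hooks run before the layer's forward, so this fires even for the
    # layer that then raises; returning None leaves the input unchanged
    seen.append((module.__class__.__name__, inputs[0].dtype))
    print(module.__class__.__name__, "input dtype:", inputs[0].dtype)

for name, module in model.named_modules():
    if len(list(module.children())) == 0:  # leaf layers only
        module.register_forward_pre_hook(log_input)

try:
    model(torch.rand(4, 300, dtype=torch.float64))
except RuntimeError:
    pass  # the first Linear (float32 weights) rejects the float64 input
```

Only the first layer gets logged, because the forward pass stops there. Its input is reported as float64.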
All of them turned out to be torch.float32. Checking my own input data next showed that it was not float32, so converting it with
valset = torch.tensor(valset,dtype=torch.float)
fixed the problem.
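Below is a self-contained reproduction of this mismatch and its fix. It uses a plain `nn.Linear` as a stand-in for the autoencoder, which is an assumption: any layer with float32 weights behaves the same way. If `valset` is already a tensor rather than a NumPy array, `valset.float()` avoids the copy-construct UserWarning that `torch.tensor(tensor)` triggers.

```python
import numpy as np
import torch
import torch.nn as nn

# Hypothetical stand-in for the autoencoder: any float32 layer behaves the same
model = nn.Linear(300, 20)
valset = np.random.rand(8, 300)  # NumPy data -> float64

raised = False
try:
    model(torch.from_numpy(valset))  # float64 input, float32 weights
except RuntimeError as err:
    raised = True  # the message text differs between PyTorch versions
    print("RuntimeError:", err)

# Fix: convert the input to float32 before the forward pass
x = torch.as_tensor(valset, dtype=torch.float32)
out = model(x)
print(out.dtype, out.shape)
```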
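import torch.nn as nn

# Build the model (the layer below is a hypothetical placeholder)
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(300, 20)
    def forward(self, x):
        return self.fc(x)
# Instantiate the model
net = Net()
# Print the size of each layer's parameters
for name, parameters in net.named_parameters():
    print(name, ':', parameters.size())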
1. Double-check the package name: torchviz
(pytorch17) [stu514-17@server5 code_Spiking_CNN_Rathi_hybrid]$ python snn.py --batch_size 200 --architecture 'VGG16' --pretrained_snn './trained_models/snn/snn_vgg16_cifar10.pth' --test_only --epochs 200 --timesteps 200
Traceback (most recent call last):
File "snn.py", line 13, in <module>
from torchviz import make_dot
ModuleNotFoundError: No module named 'torchviz'
(pytorch17) [stu514-17@server5 code_Spiking_CNN_Rathi_hybrid]$ pip install pytorchviz
ERROR: Could not find a version that satisfies the requirement pytorchviz (from versions: none)
ERROR: No matching distribution found for pytorchviz
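The name given to `pip install` is the distribution name on PyPI, while `from torchviz import make_dot` uses the module name. As far as I know, this package is published on PyPI as `torchviz`, so the command is `pip install torchviz`. Rendering the graphs also relies on Graphviz. The sketch below uses only the standard library to check which modules the current interpreter can actually import:

```python
import importlib.util

def is_importable(module_name: str) -> bool:
    """Return True if `module_name` can be found by the current interpreter."""
    return importlib.util.find_spec(module_name) is not None

for name in ("torchviz", "json"):
    print(name, "importable:", is_importable(name))
```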
(pytorch17) [stu514-17@server5 code_Spiking_CNN_Rathi_hybrid]$ python
Python 3.8.13 (default, Mar 28 2022, 11:38:47)
[GCC 7.5.0] :: Anaconda, Inc. on linux
Type "help", "copyright", "credits" or "license" for more information.
2. Different PyTorch versions differ in what they tolerate and in how they handle things by default
(36): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(37): ReLU(inplace=True)
(38): Dropout(p=0.3, inplace=False)
)
(classifier): Sequential(
(0): Linear(in_features=2048, out_features=4096, bias=False)
(1): ReLU(inplace=True)
(2): Dropout(p=0.5, inplace=False)
(3): Linear(in_features=4096, out_features=4096, bias=False)
(4): ReLU(inplace=True)
(5): Dropout(p=0.5, inplace=False)
(6): Linear(in_features=4096, out_features=10, bias=False)
)
)
)
Adam (
Parameter Group 0
amsgrad: True
betas: (0.9, 0.999)
eps: 1e-08
lr: 0.0001
weight_decay: 0.0005
)
snn.py:182: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
mask = torch.tensor(mask,dtype=torch.float)
Traceback (most recent call last):
File "snn.py", line 519, in <module>
test(epoch)
File "snn.py", line 186, in test
data = data * mask
RuntimeError: expected device cpu and dtype Float but got device cuda:0 and dtype Float
(pytorch) [stu514-17@server5 code_Spiking_CNN_Rathi_hybrid]$ python
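The traceback shows `mask` on the CPU while `data` is on `cuda:0`. Here is a minimal sketch of making the multiplication device-agnostic. The helper `apply_mask` is hypothetical and not taken from snn.py. It also follows the UserWarning's advice to use `clone().detach()`:

```python
import torch

def apply_mask(data: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    # clone().detach() replaces torch.tensor(existing_tensor);
    # .to() then matches the mask's dtype and device to the data's
    mask = mask.clone().detach().to(dtype=data.dtype, device=data.device)
    return data * mask

device = "cuda:0" if torch.cuda.is_available() else "cpu"
data = torch.ones(2, 3, device=device)
mask = torch.tensor([[1, 0, 1], [0, 1, 0]])  # int64 mask created on the CPU
out = apply_mask(data, mask)
print(out.device, out.dtype)
```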
3. How the `-` in command-line arguments is written does not matter
architecture : VGG16
weight_decay : 0.0005
pretrained_snn : ./trained_models/snn/snn_vgg16_cifar10.pth
dropout : 0.3
alpha : 0.3
scaling_factor : 0.7
lr_reduce : 10
kernel_size : 3
Files already downloaded and verified
Files already downloaded and verified
/data/student/stu514-17/anaconda3/envs/pytorch/lib/python3.5/site-packages/torch/serialization.py:453: SourceChangeWarning: source code of class 'torch.nn.parallel.data_parallel.DataParallel' has changed. you can retrieve the original source code by accessing the object's source attribute or set `torch.nn.Module.dump_patches = True` and use the patch tool to revert the changes.
warnings.warn(msg, SourceChangeWarning)
/data/student/stu514-17/anaconda3/envs/pytorch/lib/python3.5/site-packages/torch/serialization.py:453: SourceChangeWarning: source code of class 'self_models.spiking_model.VGG_SNN_STDB' has changed. you can retrieve the original source code by accessing the object's source attribute or set `torch.nn.Module.dump_patches = True` and use the patch tool to revert the changes.
warnings.warn(msg, SourceChangeWarning)
/data/student/stu514-17/anaconda3/envs/pytorch/lib/python3.5/site-packages/torch/serialization.py:453: SourceChangeWarning: source code of class 'torch.nn.modules.conv.Conv2d' has changed. you can retrieve the original source code by accessing the object's source attribute or set `torch.nn.Module.dump_patches = True` and use the patch tool to revert the changes.
warnings.warn(msg, SourceChangeWarning)
/data/student/stu514-17/anaconda3/envs/pytorch/lib/python3.5/site-packages/torch/serialization.py:453: SourceChangeWarning: source code of class 'torch.nn.modules.activation.ReLU' has changed. you can retrieve the original source code by accessing the object's source attribute or set `torch.nn.Module.dump_patches = True` and use the patch tool to revert the changes.
warnings.warn(msg, SourceChangeWarning)
/data/student/stu514-17/anaconda3/envs/pytorch/lib/python3.5/site-packages/torch/serialization.py:453: SourceChangeWarning: source code of class 'torch.nn.modules.dropout.Dropout' has changed. you can retrieve the original source code by accessing the object's source attribute or set `torch.nn.Module.dump_patches = True` and use the patch tool to revert the changes.
warnings.warn(msg, SourceChangeWarning)
/data/student/stu514-17/anaconda3/envs/pytorch/lib/python3.5/site-packages/torch/serialization.py:453: SourceChangeWarning: source code of class 'torch.nn.modules.pooling.AvgPool2d' has changed. you can retrieve the original source code by accessing the object's source attribute or set `torch.nn.Module.dump_patches = True` and use the patch tool to revert the changes.
warnings.warn(msg, SourceChangeWarning)
/data/student/stu514-17/anaconda3/envs/pytorch/lib/python3.5/site-packages/torch/serialization.py:453: SourceChangeWarning: source code of class 'torch.nn.modules.linear.Linear' has changed. you can retrieve the original source code by accessing the object's source attribute or set `torch.nn.Module.dump_patches = True` and use the patch tool to revert the changes.
warnings.warn(msg, SourceChangeWarning)
Loaded module.features.0.weight from ./trained_models/snn/snn_vgg16_cifar10.pth
Loaded module.features.3.weight from ./trained_models/snn/snn_vgg16_cifar10.pth
Loaded module.features.6.weight from ./trained_models/snn/snn_vgg16_cifar10.pth
Loaded module.features.9.weight from ./trained_models/snn/snn_vgg16_cifar10.pth
Loaded module.features.12.weight from ./trained_models/snn/snn_vgg16_cifar10.pth
Loaded module.features.15.weight from ./trained_models/snn/snn_vgg16_cifar10.pth
Loaded module.features.18.weight from ./trained_models/snn/snn_vgg16_cifar10.pth
Loaded module.features.21.weight from ./trained_models/snn/snn_vgg16_cifar10.pth
Loaded module.features.24.weight from ./trained_models/snn/snn_vgg16_cifar10.pth
Loaded module.features.27.weight from ./trained_models/snn/snn_vgg16_cifar10.pth
Loaded module.features.30.weight from ./trained_models/snn/snn_vgg16_cifar10.pth
Loaded module.features.33.weight from ./trained_models/snn/snn_vgg16_cifar10.pth
Loaded module.features.36.weight from ./trained_models/snn/snn_vgg16_cifar10.pth
Loaded module.classifier.0.weight from ./trained_models/snn/snn_vgg16_cifar10.pth
Loaded module.classifier.3.weight from ./trained_models/snn/snn_vgg16_cifar10.pth
Loaded module.classifier.6.weight from ./trained_models/snn/snn_vgg16_cifar10.pth
DataParallel(
(module): VGG_SNN_STDB(
(input_layer): PoissonGenerator()
(features): Sequential(
(0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(1): ReLU(inplace=True)
(2): Dropout(p=0.3, inplace=False)
(3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(4): ReLU(inplace=True)
(5): AvgPool2d(kernel_size=2, stride=2, padding=0)
(6): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(7): ReLU(inplace=True)
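One plausible reading of point 3 assumes the scripts use Python's `argparse`. In that case, a `-` inside a long option name becomes `_` in the attribute name, so options declared as `--batch-size` or `--batch_size` are both read back as `args.batch_size`. On the command line, though, the flag must still be spelled the way it was declared, or as an unambiguous prefix of it. Here is a sketch with hypothetical option declarations:

```python
import argparse

parser = argparse.ArgumentParser()
# Hypothetical declarations: one uses '-', the other '_'
parser.add_argument("--batch-size", type=int, default=64)
parser.add_argument("--pretrained_snn", type=str, default="")

args = parser.parse_args(["--batch-size", "200", "--pretrained_snn", "./model.pth"])
# Both are read back with underscores: '-' in the option name becomes '_'
print(args.batch_size, args.pretrained_snn)  # 200 ./model.pth
```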