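Python's built-in type() is not very useful for tensors. The tensor's own attributes and methods (.dtype, .type(), .device, .requires_grad) are much clearer.

PyTorch's default floating-point type is float (float32), while NumPy defaults to float64 (also called double). The two cannot be mixed freely: a float64 array fed into a float32 model raises a dtype error. MATLAB, as far as I know, also defaults to double.

Beyond dtypes, also watch out for:
- tensor operations on the GPU;
- whether a tensor is a trainable variable or a constant, and whether it requires gradients;
- how operations between tensors behave.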
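A quick way to see the difference between type() and the tensor's own attributes. This sketch assumes PyTorch's default dtype has not been changed with torch.set_default_dtype:

```python
import numpy as np
import torch

arr = np.zeros(3)               # NumPy default: float64
t_np = torch.from_numpy(arr)    # from_numpy keeps the dtype, so this is float64
t_default = torch.zeros(3)      # PyTorch default: float32

print(type(t_default))                            # <class 'torch.Tensor'>, says nothing about dtype
print(t_default.dtype, t_default.type())          # torch.float32 torch.FloatTensor
print(t_np.dtype, t_np.type())                    # torch.float64 torch.DoubleTensor
print(t_default.device, t_default.requires_grad)  # cpu False
```

Because torch.from_numpy preserves float64, data loaded through NumPy is the most common source of the Double/Float mismatch described below.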

(pytorch17) [stu514-17@server5 ANN2SNN_tool_chain]$ python ann2snn.py example_net  --weight_bitwidth 16 --timesteps 64 --finetune_epochs 10 --ann_weight checkpoint/example_net.pth --save_file out_snn.pth --num_workers 1 --batch_size 200 --test_batch_size 200 --device '0' --hardware 'cpu'
/data/student/stu514-17/anaconda3/envs/pytorch17/lib/python3.8/site-packages/torch/cuda/__init__.py:52: UserWarning: CUDA initialization: The NVIDIA driver on your system is too old (found version 10010). Please update your GPU driver by downloading and installing a new version from the URL: http://www.nvidia.com/Download/index.aspx Alternatively, go to: https://pytorch.org to install a PyTorch version that has been compiled with your version of the CUDA driver. (Triggered internally at  /opt/conda/conda-bld/pytorch_1607369981906/work/c10/cuda/CUDAFunctions.cpp:100.)
  return torch._C._cuda_getDeviceCount() > 0
Files already downloaded and verified
ExampleNet0(
Epoch: 0
/data/student/stu514-17/code_uzip/ANN2SNN_tool_chain/quantization.py:106: UserWarning: This overload of add_ is deprecated:
        add_(Number alpha, Tensor other)
Consider using one of the following signatures instead:
        add_(Tensor other, *, Number alpha) (Triggered internally at  /opt/conda/conda-bld/pytorch_1607369981906/work/torch/csrc/utils/python_arg_parser.cpp:882.)
  d_p.add_(weight_decay, weight_data)
59 250 Loss: 0.061 | Acc: 98.167% (11780/12000)
119 250 Loss: 0.057 | Acc: 98.263% (23583/24000)
179 250 Loss: 0.057 | Acc: 98.264% (35375/36000)
239 250 Loss: 0.055 | Acc: 98.325% (47196/48000)
/data/student/stu514-17/anaconda3/envs/pytorch17/lib/python3.8/site-packages/torch/optim/lr_scheduler.py:156: UserWarning: The epoch parameter in `scheduler.step()` was not necessary and is being deprecated where possible. Please use `scheduler.step()` to step the scheduler. During the deprecation, if epoch is different from None, the closed form is used instead of the new chainable form, where available. Please open an issue if you are unable to replicate your use case: https://github.com/pytorch/pytorch/issues/new/choose.
  warnings.warn(EPOCH_DEPRECATION_WARNING, UserWarning)
add op view1: ['relu5_out1']->['view1_out1']
add op fc1: ['view1_out1']->['fc1_out1']
end_nodes(outputs of the model): ['fc1_out1']
set conv1: Vthr tensor([4.5568]) bias tensor([0.0356, 0.0356, 0.0356], grad_fn=<SliceBackward>) in_scales 1.0
set avg_pool2d1: Vthr tensor([4.1724]) bias tensor([0.0326, 0.0326, 0.0326], grad_fn=<SliceBackward>) in_scales 4.5568132400512695
set conv2: Vthr tensor([4.8173]) bias tensor([0.0376, 0.0376, 0.0376], grad_fn=<SliceBackward>) in_scales 4.172365665435791
set avg_pool2d2: Vthr tensor([4.2452]) bias tensor([0.0332, 0.0332, 0.0332], grad_fn=<SliceBackward>) in_scales 4.817252159118652
set conv3: Vthr tensor([3.4593]) bias tensor([0.0270, 0.0270, 0.0270], grad_fn=<SliceBackward>) in_scales 4.245199203491211
set conv4: Vthr tensor([1.8026]) bias tensor([0.0141, 0.0141, 0.0141], grad_fn=<SliceBackward>) in_scales 3.459290027618408
set conv5: Vthr tensor([1.1337]) bias tensor([0.0089, 0.0089, 0.0089], grad_fn=<SliceBackward>) in_scales 1.8026059865951538
set fc1: Vthr tensor([13.5833]) bias tensor([0.1061, 0.1061, 0.1061], grad_fn=<SliceBackward>) in_scales 1.1336519718170166
Performing validation for SNN
/data/student/stu514-17/code_uzip/ANN2SNN_tool_chain/validation.py:60: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
  mask = torch.tensor(mask,dtype=torch.float)
Mean Firing ratios 0.13229199999736416, Firing ratios: [0.09463893 0.10301569 0.08803387 0.09965187 0.062542   0.08998064
 0.2368007  0.2368007  0.17916359]
SNN Prec@1 83.890, Prec@5 98.730, Time 287.67187, Loss: 0.663
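The run above prints three deprecation warnings: the old add_ overload, copy-constructing with torch.tensor(sourceTensor), and scheduler.step(epoch). The sketch below shows the non-deprecated forms.

The names d_p, weight_decay, weight_data and mask are taken from the warning lines. Everything else is a hypothetical placeholder, because the tool chain's own code is not shown: the tensor values, the Linear model, the SGD optimizer and the StepLR scheduler.

```python
import torch

# 1) Tensor.add_: the add_(Number alpha, Tensor other) overload is deprecated.
weight_data = torch.ones(3)
d_p = torch.full((3,), 0.5)
weight_decay = 5e-4
d_p.add_(weight_data, alpha=weight_decay)   # instead of d_p.add_(weight_decay, weight_data)

# 2) Copy-constructing from an existing tensor: clone + detach instead of torch.tensor(t).
mask_src = torch.rand(2, 2, requires_grad=True)
mask = mask_src.clone().detach().to(torch.float)  # instead of torch.tensor(mask_src, dtype=torch.float)

# 3) LR scheduler: call step() without the epoch argument, once per epoch.
model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.1)
for epoch in range(2):
    optimizer.step()      # training loop body omitted; params without grads are skipped
    scheduler.step()      # instead of scheduler.step(epoch)

print(d_p, mask.requires_grad, optimizer.param_groups[0]["lr"])
```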

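Recently, while using the sequitur library to quickly build an autoencoder, I ran into RuntimeError: expected scalar type Double but found Float.

The code involved:

    import torch
    from sequitur.models import LINEAR_AE

    # Linear autoencoder: 300-dim input, hidden layers of 120 and 60 units,
    # 20-dim encoding, no hidden or output activations
    model = LINEAR_AE(
      input_dim=300,
      encoding_dim=20,
      h_dims=[120, 60],
      h_activ=None,
      out_activ=None
    )
    print(model)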
    错误的意思是期待的张量类型是double但是输入的是float(有时候这个提示不一定是准确的,但归根就是类型错误),可以将所有的层的输入输出类型打印出来:

    for name, param in model.named_parameters():
        print(name,'-->',param.type(),'-->',param.dtype,'-->',param.shape)
    
     
     
     
     

      在这里插入图片描述
      学习

      发现全部都是tensor.float32 类型,再检查一下自己的输入数据类型,发现不对,使用类型转换 valset = torch.tensor(valset,dtype=torch.float) 即可。
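The same failure can be reproduced without sequitur. In this sketch nn.Linear(300, 20) is a hypothetical stand-in for the autoencoder (any module with float32 weights behaves the same), and valset is random placeholder data:

```python
import numpy as np
import torch
import torch.nn as nn

# Stand-in for the sequitur model: float32 weights, 300 input features.
model = nn.Linear(300, 20)

valset = np.random.rand(8, 300)       # NumPy data defaults to float64
x_bad = torch.from_numpy(valset)      # still float64, so it clashes with the float32 weights

try:
    model(x_bad)
    mismatch_raised = False
except RuntimeError as err:           # exact wording differs between PyTorch versions
    mismatch_raised = True
    print("RuntimeError:", err)

# Fix: cast the input to the model's dtype rather than casting the model to float64.
x_ok = torch.from_numpy(valset).float()   # or torch.as_tensor(valset, dtype=torch.float32)
out = model(x_ok)
print(out.dtype, out.shape)               # torch.float32 torch.Size([8, 20])
```

Casting the input is usually preferable to calling model.double(), which doubles memory use and slows computation on most GPUs.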

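Also, to inspect a model's parameters in PyTorch (Net below is a placeholder network):

    import torch.nn as nn
    # Build the model (hypothetical layer; substitute your own network)
    class Net(nn.Module):
        def __init__(self):
            super().__init__()
            self.fc = nn.Linear(10, 2)
        def forward(self, x):
            return self.fc(x)
    # Instantiate the model
    net = Net()
    # Print the size of each layer's parameters
    for name, parameters in net.named_parameters():
        print(name, ':', parameters.size())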
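The same loop can also report the attributes listed at the top of this post (dtype, device, whether gradients are required) and count parameters. In this sketch the Sequential network is a hypothetical placeholder, and freezing its first weight only demonstrates the requires_grad column:

```python
import torch.nn as nn

# Placeholder network (hypothetical); substitute your own model.
net = nn.Sequential(nn.Linear(10, 5), nn.ReLU(), nn.Linear(5, 2))
net[0].weight.requires_grad_(False)   # freeze one tensor to show the requires_grad column

total, trainable = 0, 0
for name, p in net.named_parameters():
    # dtype, device and requires_grad are behind most "expected ... but got ..." errors
    print(f"{name:10s} {str(tuple(p.shape)):10s} {p.dtype} {p.device} requires_grad={p.requires_grad}")
    total += p.numel()
    if p.requires_grad:
        trainable += p.numel()

print("total params:", total, "trainable:", trainable)   # total params: 67 trainable: 17
```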
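1. Double-check the package name

The GitHub repository is called pytorchviz, but the package on PyPI (and the module you import) is torchviz. That is why pip install pytorchviz fails in the log below; pip install torchviz is the correct command.

        (pytorch17) [stu514-17@server5 code_Spiking_CNN_Rathi_hybrid]$ python snn.py --batch_size 200 --architecture 'VGG16' --pretrained_snn './trained_models/snn/snn_vgg16_cifar10.pth' --test_only --epochs 200 --timesteps 200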
        Traceback (most recent call last):
          File "snn.py", line 13, in <module>
            from torchviz import make_dot
        ModuleNotFoundError: No module named 'torchviz'
        (pytorch17) [stu514-17@server5 code_Spiking_CNN_Rathi_hybrid]$ pip install pytorchviz
        ERROR: Could not find a version that satisfies the requirement pytorchviz (from versions: none)
        ERROR: No matching distribution found for pytorchviz
        (pytorch17) [stu514-17@server5 code_Spiking_CNN_Rathi_hybrid]$ python
        Python 3.8.13 (default, Mar 28 2022, 11:38:47)
        [GCC 7.5.0] :: Anaconda, Inc. on linux
        Type "help", "copyright", "credits" or "license" for more information.
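To check whether a module is importable before running a script, importlib.util.find_spec looks it up without actually importing it. The helper name is_importable is hypothetical. Listing graphviz alongside torchviz reflects the assumption that torchviz's make_dot depends on the graphviz Python package (rendering also needs the system Graphviz binaries):

```python
import importlib.util

def is_importable(module_name: str) -> bool:
    """Return True if `module_name` can be found in the current environment."""
    return importlib.util.find_spec(module_name) is not None

# The repository is named pytorchviz, but the PyPI package and the module are "torchviz".
for name in ("torchviz", "graphviz"):
    print(name, "OK" if is_importable(name) else "missing -> pip install " + name)
```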
        

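2. Different PyTorch versions apply different tolerance rules and defaults

In the pytorch17 environment, snn.py stops in test() because data and mask are on different devices: one is on the CPU and the other on cuda:0. The last prompt in the log shows the switch to the older pytorch environment.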

              (36): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
              (37): ReLU(inplace=True)
              (38): Dropout(p=0.3, inplace=False)
            )
            (classifier): Sequential(
              (0): Linear(in_features=2048, out_features=4096, bias=False)
              (1): ReLU(inplace=True)
              (2): Dropout(p=0.5, inplace=False)
              (3): Linear(in_features=4096, out_features=4096, bias=False)
              (4): ReLU(inplace=True)
              (5): Dropout(p=0.5, inplace=False)
              (6): Linear(in_features=4096, out_features=10, bias=False)
            )
          )
        )
         Adam (
        Parameter Group 0
            amsgrad: True
            betas: (0.9, 0.999)
            eps: 1e-08
            lr: 0.0001
            weight_decay: 0.0005
        )
        snn.py:182: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
          mask = torch.tensor(mask,dtype=torch.float)
        Traceback (most recent call last):
          File "snn.py", line 519, in <module>
            test(epoch)
          File "snn.py", line 186, in test
            data = data * mask
        RuntimeError: expected device cpu and dtype Float but got device cuda:0 and dtype Float
        (pytorch) [stu514-17@server5 code_Spiking_CNN_Rathi_hybrid]$ python
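One way to avoid the device mismatch regardless of PyTorch version is to put mask on data's device and dtype before multiplying. The names data and mask come from the traceback. The shapes and the Bernoulli mask are hypothetical placeholders, since the body of snn.py is not shown:

```python
import torch

# Hypothetical stand-ins for `data` and `mask` in snn.py's test().
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
data = torch.rand(4, 3, 32, 32, device=device)
mask = torch.bernoulli(torch.full((4, 3, 32, 32), 0.7))   # created on the CPU

# torch.tensor(mask, dtype=torch.float) triggers the copy-construct warning and leaves
# mask on whatever device it started on; .to() casts and moves it to match data instead.
mask = mask.to(device=data.device, dtype=data.dtype)
data = data * mask
print(data.device, data.dtype)
```

If mask starts out as a NumPy array, torch.as_tensor(mask, dtype=data.dtype, device=data.device) does the conversion and the move in one call.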
        

3. How the dashes in command-line arguments are written does not matter

                 architecture         : VGG16
                 weight_decay         : 0.0005
                 pretrained_snn       : ./trained_models/snn/snn_vgg16_cifar10.pth
                 dropout              : 0.3
                 alpha                : 0.3
                 scaling_factor       : 0.7
                 lr_reduce            : 10
                 kernel_size          : 3
        Files already downloaded and verified
        Files already downloaded and verified
        /data/student/stu514-17/anaconda3/envs/pytorch/lib/python3.5/site-packages/torch/serialization.py:453: SourceChangeWarning: source code of class 'torch.nn.parallel.data_parallel.DataParallel' has changed. you can retrieve the original source code by accessing the object's source attribute or set `torch.nn.Module.dump_patches = True` and use the patch tool to revert the changes.
          warnings.warn(msg, SourceChangeWarning)
        /data/student/stu514-17/anaconda3/envs/pytorch/lib/python3.5/site-packages/torch/serialization.py:453: SourceChangeWarning: source code of class 'self_models.spiking_model.VGG_SNN_STDB' has changed. you can retrieve the original source code by accessing the object's source attribute or set `torch.nn.Module.dump_patches = True` and use the patch tool to revert the changes.
          warnings.warn(msg, SourceChangeWarning)
        /data/student/stu514-17/anaconda3/envs/pytorch/lib/python3.5/site-packages/torch/serialization.py:453: SourceChangeWarning: source code of class 'torch.nn.modules.conv.Conv2d' has changed. you can retrieve the original source code by accessing the object's source attribute or set `torch.nn.Module.dump_patches = True` and use the patch tool to revert the changes.
          warnings.warn(msg, SourceChangeWarning)
        /data/student/stu514-17/anaconda3/envs/pytorch/lib/python3.5/site-packages/torch/serialization.py:453: SourceChangeWarning: source code of class 'torch.nn.modules.activation.ReLU' has changed. you can retrieve the original source code by accessing the object's source attribute or set `torch.nn.Module.dump_patches = True` and use the patch tool to revert the changes.
          warnings.warn(msg, SourceChangeWarning)
        /data/student/stu514-17/anaconda3/envs/pytorch/lib/python3.5/site-packages/torch/serialization.py:453: SourceChangeWarning: source code of class 'torch.nn.modules.dropout.Dropout' has changed. you can retrieve the original source code by accessing the object's source attribute or set `torch.nn.Module.dump_patches = True` and use the patch tool to revert the changes.
          warnings.warn(msg, SourceChangeWarning)
        /data/student/stu514-17/anaconda3/envs/pytorch/lib/python3.5/site-packages/torch/serialization.py:453: SourceChangeWarning: source code of class 'torch.nn.modules.pooling.AvgPool2d' has changed. you can retrieve the original source code by accessing the object's source attribute or set `torch.nn.Module.dump_patches = True` and use the patch tool to revert the changes.
          warnings.warn(msg, SourceChangeWarning)
        /data/student/stu514-17/anaconda3/envs/pytorch/lib/python3.5/site-packages/torch/serialization.py:453: SourceChangeWarning: source code of class 'torch.nn.modules.linear.Linear' has changed. you can retrieve the original source code by accessing the object's source attribute or set `torch.nn.Module.dump_patches = True` and use the patch tool to revert the changes.
          warnings.warn(msg, SourceChangeWarning)
        
         Loaded module.features.0.weight from ./trained_models/snn/snn_vgg16_cifar10.pth
         Loaded module.features.3.weight from ./trained_models/snn/snn_vgg16_cifar10.pth
         Loaded module.features.6.weight from ./trained_models/snn/snn_vgg16_cifar10.pth
         Loaded module.features.9.weight from ./trained_models/snn/snn_vgg16_cifar10.pth
         Loaded module.features.12.weight from ./trained_models/snn/snn_vgg16_cifar10.pth
         Loaded module.features.15.weight from ./trained_models/snn/snn_vgg16_cifar10.pth
         Loaded module.features.18.weight from ./trained_models/snn/snn_vgg16_cifar10.pth
         Loaded module.features.21.weight from ./trained_models/snn/snn_vgg16_cifar10.pth
         Loaded module.features.24.weight from ./trained_models/snn/snn_vgg16_cifar10.pth
         Loaded module.features.27.weight from ./trained_models/snn/snn_vgg16_cifar10.pth
         Loaded module.features.30.weight from ./trained_models/snn/snn_vgg16_cifar10.pth
         Loaded module.features.33.weight from ./trained_models/snn/snn_vgg16_cifar10.pth
         Loaded module.features.36.weight from ./trained_models/snn/snn_vgg16_cifar10.pth
         Loaded module.classifier.0.weight from ./trained_models/snn/snn_vgg16_cifar10.pth
         Loaded module.classifier.3.weight from ./trained_models/snn/snn_vgg16_cifar10.pth
         Loaded module.classifier.6.weight from ./trained_models/snn/snn_vgg16_cifar10.pth
         DataParallel(
          (module): VGG_SNN_STDB(
            (input_layer): PoissonGenerator()
            (features): Sequential(
              (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
              (1): ReLU(inplace=True)
              (2): Dropout(p=0.3, inplace=False)
              (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
              (4): ReLU(inplace=True)
              (5): AvgPool2d(kernel_size=2, stride=2, padding=0)
              (6): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
              (7): ReLU(inplace=True)
        

Further reading

https://blog.csdn.net/Hodors/article/details/118411462
https://blog.csdn.net/sinat_29957455/article/details/80412885
https://blog.csdn.net/h102897/article/details/124085184
https://blog.csdn.net/weixin_43633568/article/details/104520073
https://blog.csdn.net/zhangqiqiyihao/article/details/114699910
https://blog.csdn.net/jcjic/article/details/117996638
