利用pytorch自带的频谱归一化函数,给设定好的网络进行频谱归一化。主要用于生成对抗网络的鉴别器,效果还是很明显的。

import torch
import torch.nn as nn
 
class TestModule(nn.Module):
    def __init__(self):
        super(TestModule,self).__init__()
        self.layer1 = nn.Conv2d(16,32,3,1)
        self.layer2 = nn.Linear(32,10)
        self.layer3 = nn.Linear(32,10)
 
    def forward(self,x):
        x = self.layer1(x)
        x = self.layer2(x)

model = TestModule()

def add_sn(m):
        for name, layer in m.named_children():
             m.add_module(name, add_sn(layer))
        if isinstance(m, (nn.Conv2d, nn.Linear)):
             return nn.utils.spectral_norm(m)
        else:
             return m
         
my_model = add_sn(model)
Logo

华为开发者空间,是为全球开发者打造的专属开发空间,汇聚了华为优质开发资源及工具,致力于让每一位开发者拥有一台云主机,基于华为根生态开发、创新。

更多推荐