pytorch的spectral_norm的使用
利用pytorch自带的频谱归一化函数,给设定好的网络进行频谱归一化。主要用于生成对抗网络的鉴别器,效果还是很明显的。import torchimport torch.nn as nnclass TestModule(nn.Module):def __init__(self):super(TestModule,self).__init__()self.layer1 = nn.Conv2d(16,3
·
利用pytorch自带的频谱归一化函数,给设定好的网络进行频谱归一化。主要用于生成对抗网络的鉴别器,效果还是很明显的。
import torch
import torch.nn as nn
class TestModule(nn.Module):
def __init__(self):
super(TestModule,self).__init__()
self.layer1 = nn.Conv2d(16,32,3,1)
self.layer2 = nn.Linear(32,10)
self.layer3 = nn.Linear(32,10)
def forward(self,x):
x = self.layer1(x)
x = self.layer2(x)
model = TestModule()
def add_sn(m):
for name, layer in m.named_children():
m.add_module(name, add_sn(layer))
if isinstance(m, (nn.Conv2d, nn.Linear)):
return nn.utils.spectral_norm(m)
else:
return m
my_model = add_sn(model)
更多推荐
已为社区贡献1条内容
所有评论(0)