CNN计算网络各层参数量与FLOPs

发布于:2021-09-13 20:46:20

CNN中经常需要考虑到网络的参数量与计算量等问题,具体的计算方法为:



其中,K是卷积核的大小,Cin核Cout表示输入与输出的通道数,H与W表示特征图的大小。


此外,可以通过python中的stat模块计算与验证,代码:


import torch
import torch.nn as nn
from torchstat import stat

class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv = nn.Sequential()
self.conv.add_module("conv", nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, stride=1, bias=False, padding=1))
self.conv.add_module("bn", nn.BatchNorm2d(64))
# self.conv.add_module("")
self.fc = nn.Linear(64*224*224,100, bias=True)

def forward(self, x):
x = self.conv(x)
x = x.view(-1, 64*224*224)
x = self.fc(x)
return x




model = Net()
stat(model, (3, 224, 224))

输出为:



注意:torchstst的版本需要为0.0.6,安装方法:? pip install torchstat==0.0.6


?


?

相关推荐