• 问答
  • 技术
  • 实践
  • 资源
关于 PyTorch 中 .eval() 的问题
技术讨论

from future import print_function
import torch
from PIL import Image
import numpy as np
from models.u_net import UNet
from torch.autograd import Variable
from torchvision.models.segmentation import deeplabv3_resnet50

from models.resnet import resnet34
from torchvision.transforms import ToPILImage

---------------------------------------------------------------

val_path = './cloud/test/'
predict_path = 'cloud/pred-Unet-deconv-newdata-pooling-nn.conv-2bs/'
number_photo = 1352
model = UNet(3,1).cuda()

model_path = 'checkpoint/Unet-deconv-newdata-pooling-nn.conv-2bs/model/netG_178.pth'

model.load_state_dict(torch.load(model_path))
##model=model.eval()

---------------------------------------------------------------

transform1 = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
]
)
for i in range(1,number_photo+1):

test_image = Image.open(val_path+ str(i) + '.png').convert('RGB')

test_image = Image.open(val_path + str(i) + '.png')

img = transform1(test_image)
img = img.unsqueeze(0)
img = img.cuda()
img = Variable(img)

label_image = model(img)

label_image = label_image.squeeze(0)

label_image[label_image >= 0.5] = 1
label_image[label_image < 0.5] = 0
label_image = label_image.cpu()
a = transforms.ToPILImage()(label_image)
a.save(predict_path+ str(i) + '.png')
print('转换第%d张图' % (i))
i = i + 1
请问一下在我的测试代码中加上model.eval()和去除model.eval()结果差别非常大,没有加的话,在不同模型上都能测试出结果,而且结果也会有差别,加上之后效果非常差。请问一下这是什么原因造成的?

有人说是我的网络结构问题,下采样不能将 self.relu(self.bn()) 这么嵌套

 c1 = self.relu(self.bn64(self.in_conv1(input)))
    c2 = self.relu(self.bn64(self.in_conv2(c1)))

    e1_ = self.relu(self.bn128(self.conv1_1(self.pooling1(c2))))
    e1 = self.relu(self.bn128(self.conv1_2(e1_)))

    e2_ = self.relu(self.bn256(self.conv2_1(self.pooling1(e1))))
    e2 = self.relu(self.bn256(self.conv2_2(e2_)))

    e3_ = self.relu(self.bn512(self.conv3_1(self.pooling1(e2))))
    e3 = self.relu(self.bn512(self.conv3_2(e3_)))

    e4_ = self.relu(self.bn1024(self.conv4_1(self.pooling1(e3))))
    e4 = self.relu(self.bn1024(self.conv4_2(e4_)))

    d1_1 = self.upsampling1(e4)
    d1_2 = torch.cat([d1_1, e3], dim=1)
    d1_ = self.relu(self.bn512(self.conv5_1(d1_2)))
    d1 = self.relu(self.bn512(self.conv5_2(d1_)))

    d2_1 = self.upsampling2(d1)
    d2_2 = torch.cat([d2_1, e2], dim=1)
    d2_ = self.relu(self.bn256(self.conv6_1(d2_2)))
    d2 = self.relu(self.bn256(self.conv6_2(d2_)))

    d3_1 = self.upsampling3(d2)
    d3_2 = torch.cat([d3_1, e1], dim=1)
    d3_ = self.relu(self.bn128(self.conv7_1(d3_2)))
    d3 = self.relu(self.bn128(self.conv7_2(d3_)))

    d4_1 = self.upsampling4(d3)
    d4_2 = torch.cat([d4_1,c2], dim=1)
    d4_ = self.relu(self.bn64(self.conv8_1(d4_2)))
    d4 = self.relu(self.bn64(self.conv8_2(d4_)))

    output = self.sigmoid(self.out_conv(d4))

    return output

    麻烦各位大神给出指点
  • 0
  • 6
  • 551
收藏
暂无评论
Silencewang
大咖

新疆大学 ·

  • 0

    关注
  • 0

    获赞
  • 0

    精选文章
文章专栏
  • Silencewang的专栏
作者文章
更多
  • 关于 PyTorch 中 .eval() 的问题
    551