首页 文章 python code

python code

2024-06-17 19:34  浏览数:421  来源:Libra    

# This is the code from the p1ch2/3_cyclegan notebook
import torch
import torch.nn as nn
class ResNetBlock(nn.Module):
def __init__(self, dim):
super(ResNetBlock, self).__init__()
self.conv_block = self.build_conv_block(dim)
def build_conv_block(self, dim):
conv_block = []
conv_block += [nn.ReflectionPad2d(1)]
conv_block += [
nn.Conv2d(dim, dim, kernel_size=3, padding=0, bias=True),
nn.InstanceNorm2d(dim),
nn.ReLU(True),
]
conv_block += [nn.ReflectionPad2d(1)]
conv_block += [
nn.Conv2d(dim, dim, kernel_size=3, padding=0, bias=True),
nn.InstanceNorm2d(dim),
]
return nn.Sequential(*conv_block)
def forward(self, x):
out = x + self.conv_block(x)
return out
class ResNetGenerator(nn.Module):
def __init__(self, input_nc=3, output_nc=3, ngf=64, n_blocks=9):
assert n_blocks >= 0
super(ResNetGenerator, self).__init__()
self.input_nc = input_nc
self.output_nc = output_nc
self.ngf = ngf
model = [
nn.ReflectionPad2d(3),
nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=True),
nn.InstanceNorm2d(ngf),
nn.ReLU(True),
]
n_downsampling = 2
for i in range(n_downsampling):
mult = 2**i
model += [
nn.Conv2d(
ngf * mult,
ngf * mult * 2,
kernel_size=3,
stride=2,
padding=1,
bias=True,
),
nn.InstanceNorm2d(ngf * mult * 2),
nn.ReLU(True),
]
mult = 2**n_downsampling
for i in range(n_blocks):
model += [ResNetBlock(ngf * mult)]
for i in range(n_downsampling):
mult = 2 ** (n_downsampling - i)
model += [
nn.ConvTranspose2d(
ngf * mult,
int(ngf * mult / 2),
kernel_size=3,
stride=2,
padding=1,
output_padding=1,
bias=True,
),
nn.InstanceNorm2d(int(ngf * mult / 2)),
nn.ReLU(True),
]
model += [nn.ReflectionPad2d(3)]
model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
model += [nn.Tanh()]
self.model = nn.Sequential(*model)
def forward(self, input):
# here we move to 0-1 input and 0-1 output
# usually one would think about writing this differently
# for efficiency (e.g. absorbing the 255 into the first conv
return self.model(input * 255) / 2 + 0.5
def get_pretrained_model(model_path, map_location=None):
netG = ResNetGenerator()
model_data = torch.load(model_path, map_location=map_location)
netG.load_state_dict(model_data)
netG.eval()
for p in netG.parameters():
netG.requires_grad_(False)
return netG
if __name__ == "__main__":
import sys
if len(sys.argv) < 3:
print("Call as {} zebra_weights.pt traced_zebra_model.pt".format(sys.argv[0]))
sys.exit(1)
model = get_pretrained_model(sys.argv[1], map_location="cpu")
traced_model = torch.jit.trace(model, torch.randn(1, 3, 227, 227))
traced_model.save(sys.argv[2])
# img = Image.open("../data/p1ch2/horse.jpg")
# out_img.save('../data/p1ch2/zebra.jpg')



声明:以上文章均为用户自行添加,仅供打字交流使用,不代表本站观点,本站不承担任何法律责任,特此声明!如果有侵犯到您的权利,请及时联系我们删除。

字符:    改为:
去打字就可以设置个性皮肤啦!(O ^ ~ ^ O)