Artistic Neural Style Transfer with PyTorch
Jun 9, 2020 • 9 Minute Read
Overview
In this guide, you will implement the Neural Style Transfer (NST) algorithm in PyTorch. NST lets you produce, for example, a Picasso-style rendering of a photograph: it creates a new image that mixes the style of one image (the painting) with the content of the other (the input photograph).
Thanks to Leon A. Gatys, et al. for their paper, A Neural Algorithm of Artistic Style.
To make the model faster and more accurate, a pre-trained VGG-Net-19 (Visual Geometry Group) network is used. The model is trained on ImageNet images and can be downloaded through the torchvision API.
The theory behind neural style transfer has been covered in previous guides based on the TensorFlow framework. It is highly recommended that you go through them first: Part 1 covers the theoretical aspects and VGG-Net, and Part 2 covers the losses involved in creating AI digital art.
This guide is a code walkthrough of the PyTorch implementation, following the process step by step.
Importing Libraries
To work with PyTorch, import the torch library. The PIL imaging library will be used to manipulate the images. From the torchvision package, import the transforms and models modules; models provides the pre-trained vgg19 network.
%matplotlib inline
from PIL import Image
from io import BytesIO
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.optim as optim
import requests
from torchvision import transforms, models
Build the Model
Load the VGG-Net-19 model with pretrained=True. The number 19 denotes the number of weight layers in the network. VGG is also used for classification tasks such as face recognition, but in NST you only deal with its convolutional features. Calling param.requires_grad_(False) freezes all VGG parameters, since you're only optimizing the target image.
vgg = models.vgg19(pretrained=True).features
for param in vgg.parameters():
    param.requires_grad_(False)
But there's a catch in PyTorch! You have to check whether a GPU is available on your system. If it is, move the vgg model to the GPU; if not, the model will run on the CPU. And if you don't have a GPU? No worries, you can use Google Colab.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
vgg.to(device)
The smaller the image, the faster the processing. The transforms module of torchvision handles pre-processing and normalizing the image. The function below loads an image from a URL or a local path, resizes it, and converts it into a normalized tensor.
def load_image(img_path, max_size=400, shape=None):
    # Accept either a URL or a local file path.
    if "http" in img_path:
        response = requests.get(img_path)
        image = Image.open(BytesIO(response.content)).convert('RGB')
    else:
        image = Image.open(img_path).convert('RGB')
    # Cap the size of large images to speed up processing.
    if max(image.size) > max_size:
        size = max_size
    else:
        size = max(image.size)
    if shape is not None:
        size = shape
    in_transform = transforms.Compose([
        transforms.Resize(size),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406),
                             (0.229, 0.224, 0.225))])
    # Drop any alpha channel and add the batch dimension.
    image = in_transform(image)[:3, :, :].unsqueeze(0)
    return image
style = load_image("Picasso.jpg").to(device)
content = load_image("houses.png").to(device)
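The style image does not have to match the content image's size, because the Gram matrices used later do not depend on spatial dimensions. If you prefer the two images to match anyway, the shape parameter makes this easy (an optional tweak, not part of the original flow):
# Optional: resize the style image to the content image's height and width.
style = load_image("Picasso.jpg", shape=content.shape[-2:]).to(device)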
Next, define a helper that converts a tensor back into a NumPy image for display, undoing the normalization applied earlier.
def im_convert(tensor):
    image = tensor.to("cpu").clone().detach()
    image = image.numpy().squeeze()
    image = image.transpose(1, 2, 0)
    # Undo the ImageNet normalization.
    image = image * np.array((0.229, 0.224, 0.225)) + np.array((0.485, 0.456, 0.406))
    image = image.clip(0, 1)
    return image
Display the images side-by-side.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(20, 10))
ax1.imshow(im_convert(content))
ax1.set_title("Content-Image",fontsize = 20)
ax2.imshow(im_convert(style))
ax2.set_title("Style-Image", fontsize = 20)
plt.show()
Print the network to inspect the indices of its layers; these indices are used below to map layers to names.
print(vgg)
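Depending on your torchvision version, the output begins like this (truncated):
Sequential(
  (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (1): ReLU(inplace=True)
  (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (3): ReLU(inplace=True)
  (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  ...
)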
Intermediate Layers for Style and Content
Deeper layers of VGG-19 extract the most complex features, so conv4_2 is assigned to extract the content representation. The first convolutional layer of each block, conv1_1 through conv5_1, is used for style; these shallow layers detect features such as lines and edges.
def get_features(image, model, layers=None):
    if layers is None:
        layers = {'0': 'conv1_1',
                  '5': 'conv2_1',
                  '10': 'conv3_1',
                  '19': 'conv4_1',
                  '21': 'conv4_2',  # content representation
                  '28': 'conv5_1'}
    features = {}
    x = image
    # Run the image through the network, saving the outputs of the chosen layers.
    for name, layer in model._modules.items():
        x = layer(x)
        if name in layers:
            features[layers[name]] = x
    return features
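As a quick illustrative check, the returned dictionary holds one feature map per named layer:
feats = get_features(content, vgg)
print(list(feats.keys()))
# ['conv1_1', 'conv2_1', 'conv3_1', 'conv4_1', 'conv4_2', 'conv5_1']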
Loss Functions
Calculate the Gram matrices for each style layer. A Gram matrix captures the correlations between a layer's feature maps, which is what encodes the style of an image. You still want the houses and the lake in the target image, so start by cloning the content image and then iteratively change its style.
def gram_matrix(tensor):
    # batch size (assumed to be 1), depth, height, width
    _, d, h, w = tensor.size()
    # Flatten the spatial dimensions and compute the channel correlations.
    tensor = tensor.view(d, h * w)
    gram = torch.mm(tensor, tensor.t())
    return gram
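For instance (an illustrative check), conv1_1 has 64 channels, so its Gram matrix is 64 x 64 regardless of the image's spatial size:
print(gram_matrix(get_features(content, vgg)['conv1_1']).shape)
# torch.Size([64, 64])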
content_features = get_features(content, vgg)
style_features = get_features(style, vgg)
style_grams = {layer: gram_matrix(style_features[layer]) for layer in style_features}
target = content.clone().requires_grad_(True).to(device)
Assigning Weights
A weight is assigned to each style layer. Weighting the earlier layers more heavily produces larger style artifacts in the final image.
style_weights = {'conv1_1': 1.5,
                 'conv2_1': 0.80,
                 'conv3_1': 0.25,
                 'conv4_1': 0.25,
                 'conv5_1': 0.25}
content_weight = 1e-2
style_weight = 1e9
Run the Model
These weights feed into the total loss, which the Adam optimizer minimizes by updating the target image. Define how often to display intermediate results and how many steps to run. Putting everything together: extract the target's features from VGG-Net and calculate the content loss; then, for each style layer, get the style representation, weight it appropriately, and add it to the style loss. Finally, calculate the total loss!
In gradient descent for a neural network, the weights are adjusted; in NST they are kept fixed and the image's pixels are adjusted instead. The gradients of the loss are backpropagated to the input, thus transforming the input image.
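You can verify this setup with a quick illustrative check: the target image requires gradients, while the frozen VGG weights do not.
print(target.requires_grad)                            # True: the pixels get optimized
print(any(p.requires_grad for p in vgg.parameters()))  # False: the weights stay frozen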
show = 400  # display an intermediate image every 400 steps
optimizer = optim.Adam([target], lr=0.01)
steps = 7000

for ii in range(1, steps + 1):
    target_features = get_features(target, vgg)
    # Content loss: distance from the content representation.
    content_loss = torch.mean((target_features['conv4_2'] - content_features['conv4_2'])**2)
    # Style loss: weighted sum over the style layers.
    style_loss = 0
    for layer in style_weights:
        target_feature = target_features[layer]
        target_gram = gram_matrix(target_feature)
        _, d, h, w = target_feature.shape
        style_gram = style_grams[layer]
        layer_style_loss = style_weights[layer] * torch.mean((target_gram - style_gram)**2)
        style_loss += layer_style_loss / (d * h * w)
    total_loss = content_weight * content_loss + style_weight * style_loss
    # Update the target image's pixels.
    optimizer.zero_grad()
    total_loss.backward()
    optimizer.step()
    if ii % show == 0:
        print('Total loss: ', total_loss.item())
        plt.imshow(im_convert(target))
        plt.show()
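Once the loop finishes, you may want to keep the result. One simple way (a sketch; the filename stylized.png is just an example) is matplotlib's imsave, since im_convert already returns an array with values in [0, 1]:
# Save the final stylized image to disk.
plt.imsave('stylized.png', im_convert(target))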
Conclusion
Great work! The loss decreases after every step. Play with your creation by adjusting the weights and the learning rate. You may also try the L-BFGS optimizer.
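If you do try L-BFGS, note that PyTorch's optim.LBFGS expects a closure that re-evaluates the loss on every step. A minimal sketch, reusing the loss computation from the loop above:
optimizer = optim.LBFGS([target])

def closure():
    # Recompute the total loss and its gradients for L-BFGS.
    optimizer.zero_grad()
    target_features = get_features(target, vgg)
    content_loss = torch.mean((target_features['conv4_2'] - content_features['conv4_2'])**2)
    style_loss = 0
    for layer in style_weights:
        target_gram = gram_matrix(target_features[layer])
        _, d, h, w = target_features[layer].shape
        style_loss += style_weights[layer] * torch.mean((target_gram - style_grams[layer])**2) / (d * h * w)
    total_loss = content_weight * content_loss + style_weight * style_loss
    total_loss.backward()
    return total_loss

for ii in range(50):  # each L-BFGS step performs several internal loss evaluations
    optimizer.step(closure)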
This guide gave you a general idea of how to code NST using PyTorch. I would recommend going through the NST guides based on TensorFlow for a better understanding of the terms involved (losses, VGG-Net, cost function, etc.).
I hope you learned something new today. If you need any help with machine learning implementations, feel free to contact me.