Deep learning: Understanding how neural networks work
Beneath the AI we all know and love is the humble neural network. Here's what they are, how they're used and architected, and how they learn over time.
May 29, 2024 • 11 Minute Read
With the buzz around artificial intelligence (AI) these days, tools like ChatGPT seem to get all the attention. But the real secret sauce behind AI is the neural network. Whether it’s self-driving cars, spam detection, route optimization, or just zhuzhing up your photos for Instagram, it’s all made possible by the humble neural net.
In this article, we’ll demystify neural networks, digging into core concepts like neurons, layers, weights and biases, learning rates, and backpropagation.
What is a neural network?
If you’ve heard anything about a neural network, it’s probably something like, “It’s designed to work like the human brain.” Our brain is made of neurons—hence, a “neural” network.
If your neuroscience is a little rusty, don’t worry. In reality, it’s just an algorithm, or set of instructions, that’s used in deep learning to do things like image recognition, object detection, fraud detection, and natural language processing (NLP).
But wait, there’s another term I snuck in there: deep learning. What’s that? Is it the same as machine learning (ML)?
Perhaps an illustration will help differentiate between artificial intelligence, machine learning, and deep learning.
Artificial intelligence is at the top level. This is an overarching term that means we’re enabling computers to mimic human behavior, like a computer playing chess against a human.
Machine learning is a subset of AI. ML is about algorithms using data to learn and improve performance over time. For instance, you pass in data about what credit card fraud looks like, the computer learns it, and then the computer can predict if a new incoming transaction is fraudulent.
Deep learning is a subset of ML. Here’s where the neural networks come in. This category of machine learning is typically used for more complex problems than “regular” machine learning, and it uses more data and more computing power. Using a structure/algorithm like the human brain gives us the power we need.
What are common types of neural networks?
There are several types of neural networks, and each has a niche based on the data and problem you’re trying to solve. Here are a few of the more common networks and what they’re used for.
Convolutional neural networks (CNN): Primarily used for image and video recognition, image classification, medical image analysis, and natural language processing
Recurrent neural networks (RNN): Best for sequential data, such as time series analysis (think stock market analysis or weather forecasting), language modeling, and speech recognition
Generative adversarial networks (GANs): Used for generating new data that resembles the training data; popular in image generation, photo enhancement, and creating realistic art. Visit https://thisxdoesnotexist.com to see GANs in action. GANs are also leveraged in cybersecurity.
Transformer: Designed to handle sequential data, like text or time-series data, in a very efficient and effective manner. ChatGPT is built using a transformer neural network. Kesha Williams has written a great article about Transformers in Gen AI.
What are the key components of a neural network?
Now that we understand what neural networks are and what they’re used for, let’s talk about how they’re made.
What are neurons and layers in a neural network?
Here again, an illustration might be helpful. This example network is used to do image detection, specifically on a picture of an elephant.
The input layer: The input layer is what you pass in. In our case, it’s an image, but even an image is really just data in the eyes of the computer. It could also be data like credit card transactions, insurance claims, or pictures of stop signs.
The hidden layers: Here’s where the work happens to identify the image (or whatever it is the neural network is meant to do). In our example, each layer is concerned with a different part of the elephant. One is focused only on color, another on counting the number of legs, and so on.
The output layer: This is the “answer” the network gives after it runs the data through the hidden layers. This is an elephant or it’s not. This is credit card fraud or it’s not. That kind of thing.
So what are the neurons?
To understand neurons (sometimes called nodes), let’s go to an even simpler example: an image of a smiley face. Remember that a computer sees images as data. In the case of this image, it “sees” each pixel in the image as a binary of 0 (for black) or 1 (for white).
Because the image is 7 pixels by 7 pixels, that means we have 49 (7x7) pieces of data to feed into the network. Each piece of data becomes a neuron in the input layer. So effectively, a neuron holds some value in the network.
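Here's a minimal sketch of that idea in Python. The pixel values below are made up for illustration; the point is just that a 7x7 grid flattens into 49 values, one per input neuron.

```python
# A tiny 7x7 "image": 0 = black pixel, 1 = white pixel (made-up data).
image = [
    [1, 1, 1, 1, 1, 1, 1],
    [1, 0, 0, 1, 0, 0, 1],
    [1, 0, 0, 1, 0, 0, 1],
    [1, 1, 1, 1, 1, 1, 1],
    [1, 0, 1, 1, 1, 0, 1],
    [1, 1, 0, 0, 0, 1, 1],
    [1, 1, 1, 1, 1, 1, 1],
]

# Flatten the grid into a single list of 49 values -- one per input neuron.
input_layer = [pixel for row in image for pixel in row]

print(len(input_layer))  # 49 neurons in the input layer
```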
Pulling everything together in a more realistic example of identifying a handwritten “2,” we have something that looks like this:
How do weights impact a neural network?
In the hidden layers of a neural network, not all things are created equal. Going back to our elephant example, some characteristics of the elephant are more important—or hold more weight—than others.
For instance, lots of animals have four legs, so we shouldn’t give “has four legs” a lot of weight when trying to identify the elephant. But the trunk or tusks would hold more weight, as these are more unique to an elephant.
So, we’ve added weights in the yellow circles in the image below.
As the network does its thing in the hidden layers, some characteristics are given more weight than others, helping us get to a more accurate prediction in the output layer.
What is the purpose of activation functions?
Another important concept to understand in neural networks is the activation function. In simple terms, the activation function decides which information should move forward through the network and how much of it gets through.
Let’s look at another example, this one for predicting whether someone will buy travel insurance when they purchase a trip.
Step 1: The neuron receives inputs, each input having its own weight or importance.
Step 2: It calculates a sum of these inputs (the sigma symbol, Σ, in the middle of the diagram), adjusting for their weights.
Next up, the activation function part . . .
Step 3: The activation function looks at the sum and decides the output of the neuron. We’re going to use a simple function called a step function that says “if the sum is greater than a threshold of .5, then return 1; otherwise, return 0.” (Threshold is a value that we can set, and we chose .5.)
And the final answer . . .
Our sum of .7 is greater than .5, so in this example, we’d return 1 (true) and say that for this transaction, the traveler would buy travel insurance.
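The three steps above can be sketched as a single artificial neuron in code. The inputs and weights here are made-up values, chosen so the weighted sum comes out to .7 to match the example:

```python
def step(x, threshold=0.5):
    """Step activation: fire (1) if the weighted sum clears the threshold."""
    return 1 if x > threshold else 0

def neuron(inputs, weights, activation):
    # Step 2: weighted sum of the inputs.
    total = sum(i * w for i, w in zip(inputs, weights))
    # Step 3: the activation function decides the output.
    return activation(total)

# Hypothetical travel-insurance inputs (trip cost, destination, etc.),
# already scaled to 1.0, with made-up weights that sum to 0.7.
inputs = [1.0, 1.0, 1.0]
weights = [0.2, 0.4, 0.1]

print(neuron(inputs, weights, step))  # sum is .7, which beats .5, so: 1
```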
For this illustration, we used a simple activation function to return a binary value of 0 or 1. But there are many other activation functions. Here are some of the more common ones:
Sigmoid: Takes in a value and returns a value between 0 and 1, which can be read as a probability. For example: There’s a probability of 91% that the handwritten digit is a 2.
Rectified linear unit function (ReLU): Takes in a value X and returns the max of 0 or X. For example: Pass in a negative value X, and get a return value of 0; pass in a positive value X, and get a return value of X.
Softmax: Used on the last layer of a multi-class classification problem. For example: Classify whether something is a dog, cat, chicken, or goat.
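All three of these can be sketched in a few lines of Python. These are simplified versions for illustration; real frameworks use numerically hardened implementations:

```python
import math

def sigmoid(x):
    """Squash any value into the range (0, 1)."""
    return 1 / (1 + math.exp(-x))

def relu(x):
    """Return the max of 0 or x: negatives become 0, positives pass through."""
    return max(0.0, x)

def softmax(scores):
    """Turn a list of scores into probabilities that sum to 1."""
    exps = [math.exp(s) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

print(round(sigmoid(2.3), 2))   # ~0.91 -- "91% this digit is a 2"
print(relu(-3.0), relu(3.0))    # 0.0 3.0
print(softmax([2.0, 1.0, 0.5, 0.1]))  # dog/cat/chicken/goat probabilities
```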
How do neural networks learn?
You have a neural network with its neurons and layers and activation functions. But how does it actually “learn” anything?
Learning comes from training. And training just means we provide lots and lots of labeled (i.e., “this is an elephant”) examples to the network until it “learns” and has a high rate of accuracy making predictions. The actual process of “learning” happens through backpropagation.
What is backpropagation?
To give you some insight into how this works, let’s see a human example. Imagine a teacher gives a group of school students puzzle pieces of an elephant and asks them to figure out what it is.
The students get into groups, with each group assigned a section of the puzzle. Once they come up with an answer about their section, they take it to their teacher’s assistant. (These are the “hidden layers” of the network.) The teacher’s assistant gathers up all the answers, takes them to the head teacher, and together they arrive at an overall answer of “we don’t think this is an elephant.”
The head teacher goes to the all-knowing label (the source of truth). The label says this IS an elephant, and tells the teacher to go fix the answer.
The teacher goes back to the teacher’s assistant, who takes it back to the students, telling them they got the incorrect answer and they need to adjust to get the correct one.
This process of moving back through the network and correcting errors is called backpropagation, short for “backward propagation of errors.”
What is the purpose of loss functions?
Before you can backpropagate through a network and correct your errors, you need to know what to correct and by how much.
In neural network terms, we need to figure out how far off we were in our final answer. This is called the loss (or sometimes called error or cost).
For example, maybe we predicted that there’s a probability of .92 that this is an elephant, and the real answer should have been 1. Or perhaps the model was way off in its prediction and said there’s a probability of only .35.
Pulling in some math, the distance between what we wanted and what we got in these two examples is:
1 - .92 = .08 (a small loss or error; the model is pretty close)
1 - .35 = .65 (a large-ish loss or error; the model needs a lot of improvement)
Calculating the distance between what you wanted and what you predicted is done with a loss function.
Mean squared error (MSE): Used for regression problems like predicting continuous values of home or stock prices
Cross-entropy loss: Used for classification problems like “is this a cat or a dog?” or “is this email spam or not?”
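Here's a sketch of both loss functions in Python. These are simplified single-purpose versions, not the exact implementations a given framework uses. Note how cross-entropy punishes the confident-but-wrong .35 elephant prediction much harder than the near-miss at .92:

```python
import math

def mean_squared_error(predicted, actual):
    """Average of squared differences -- used for regression problems."""
    return sum((p - a) ** 2 for p, a in zip(predicted, actual)) / len(actual)

def binary_cross_entropy(predicted, actual):
    """Heavily penalizes confident wrong answers -- used for classification."""
    return -sum(a * math.log(p) + (1 - a) * math.log(1 - p)
                for p, a in zip(predicted, actual)) / len(actual)

# The two elephant predictions from above; the true label is 1 in both cases.
print(binary_cross_entropy([0.92], [1]))  # small loss (~0.08)
print(binary_cross_entropy([0.35], [1]))  # large loss (~1.05)
```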
What is gradient descent?
Once you know how far off you are, you need to work towards improving that number. In other words, you want to minimize the loss (error/cost). The process of finding the minimum is known as gradient descent; a widely used variant of it is stochastic gradient descent (SGD).
Kind of like descending a mountain to get to the bottom, the gradient descent algorithm tries to find the optimal point at the bottom of the curve by descending the gradient of this line, iterating until it finds that lowest point.
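To make that concrete, here's gradient descent on a made-up, one-dimensional loss curve, loss(w) = (w − 3)², whose lowest point is at w = 3. Real networks do this across millions of weights at once, but each weight follows this same recipe:

```python
def gradient_descent(start, learning_rate=0.1, steps=100):
    """Descend the curve loss(w) = (w - 3)^2, whose minimum is at w = 3."""
    w = start
    for _ in range(steps):
        gradient = 2 * (w - 3)         # derivative (slope) of (w - 3)^2
        w -= learning_rate * gradient  # take one step downhill
    return w

print(gradient_descent(start=10.0))  # converges to ~3.0, the lowest point
```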
What is a learning rate?
As gradient descent is making its “steps” down the curve, the learning rate is effectively the size of its steps. And in general, you want to use small steps so you don’t miss something.
For example, if you had a large learning rate, and you were trying to find the optimal point on this line, you’d “step over” the gap down and miss the optimal point.
But with a smaller learning rate (i.e., smaller “steps”), you make a nice descent down the line to the optimal point.
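You can see both behaviors on the same made-up curve, loss(w) = (w − 3)². With a small learning rate the steps settle at the bottom; with a learning rate that's too large, each step overshoots the minimum and lands farther away than where it started:

```python
def descend(start, learning_rate, steps=20):
    """Minimize loss(w) = (w - 3)^2 using a given step size."""
    w = start
    for _ in range(steps):
        w -= learning_rate * 2 * (w - 3)  # step size scales the gradient
    return w

print(descend(10.0, learning_rate=0.1))   # settles near the optimum at 3
print(descend(10.0, learning_rate=1.05))  # overshoots worse on every step
```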
In mathematical terms, the learning rate determines how much to update each weight in the neural network as you’re doing backpropagation. The update is the learning rate multiplied by the derivative of the loss with respect to that weight (∂loss/∂weight), and it gets subtracted from the current weight.
This will result in a very small number, which means you’re making tiny updates to weights as you move back through the network. For example, with a single neuron and weight, perhaps by the time you get back to the top of the network, a weight is updated from 0.3 to 0.3008, and then you make a forward pass through the network again using the new weights.
Once you make it to the end, calculate the loss function again, figure out how much to update weights, then backpropagate to update them. This forward and backpropagation continues until you’ve minimized the overall loss for the network and get accurate predictions.
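Pulling the forward pass, the loss, the gradient, and the learning rate together, here's a toy training loop for a single neuron with one weight. The input, target, and starting weight are made-up values; squared error stands in for the loss function:

```python
def train(x, target, weight, learning_rate=0.01, epochs=100):
    """Repeat: forward pass -> loss -> gradient -> tiny weight update."""
    for _ in range(epochs):
        prediction = weight * x               # forward pass
        loss = (prediction - target) ** 2     # squared-error loss
        grad = 2 * (prediction - target) * x  # d(loss)/d(weight)
        weight -= learning_rate * grad        # the backpropagation update
    return weight

# One input of 2.0 with a target output of 1.0: the ideal weight is 0.5.
print(train(x=2.0, target=1.0, weight=0.3))  # creeps toward 0.5
```

Each pass nudges the weight only slightly, but over many epochs those tiny updates minimize the loss, which is exactly the learning process described above.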
Learning more about neural networks
There you have it! There’s a LOT more to neural networks, but hopefully this article has given you a good overall sense of what they’re used for, how they’re architected, and how they learn and improve over time.
If you want to dig deeper on this topic, check out these other resources from Pluralsight.