What You Need to Know About Deep Residual Shrinkage Networks
-
What is a deep residual shrinkage network? Why was this concept proposed? What are its core steps? This article addresses these questions by discussing related research on deep residual shrinkage networks and shares insights with readers.
The deep residual network (ResNet) won the Best Paper Award at the 2016 CVPR conference and has been cited 38,295 times on Google Scholar to date.
The deep residual shrinkage network is a novel improved version of the deep residual network, essentially integrating deep residual networks, attention mechanisms, and soft thresholding functions.
To some extent, the working principle of the deep residual shrinkage network can be understood as follows: the attention mechanism identifies unimportant features, which the soft thresholding function then sets to zero; equivalently, it identifies important features and retains them. This enhances the ability of deep neural networks to extract useful features from noisy signals.
First, when classifying samples, noise is inevitably present, such as Gaussian noise, pink noise, or Laplacian noise. More broadly, samples may contain information irrelevant to the current classification task, which can also be considered noise. This noise may adversely affect classification performance. (Soft thresholding is a key step in many signal denoising algorithms.)
For example, when chatting by the roadside, the conversation may be mixed with sounds like car horns or wheels. When performing speech recognition on these sound signals, the recognition accuracy will inevitably be affected by these noises.
From a deep learning perspective, the features corresponding to these noises (e.g., car horns or wheels) should be eliminated within the deep neural network to avoid impacting speech recognition performance.
Second, even within the same dataset, the noise levels of individual samples often vary. (This aligns with the concept of attention mechanisms. For instance, in an image dataset, the positions of target objects may differ across images; the attention mechanism can focus on the target object's location for each image.)
For example, when training a cat-dog classifier, among five images labeled "dog," the first may include a dog and a mouse, the second a dog and a goose, the third a dog and a chicken, the fourth a dog and a donkey, and the fifth a dog and a duck.
During training, the classifier will inevitably be disturbed by irrelevant objects like mice, geese, chickens, donkeys, and ducks, reducing classification accuracy. If we can identify these irrelevant objects and eliminate their corresponding features, the classifier's accuracy may improve.
Soft thresholding is a core step in many signal denoising algorithms, eliminating features with absolute values below a certain threshold and shrinking features with absolute values above the threshold toward zero. It can be implemented using the following formula:
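$$
y =
\begin{cases}
x - \tau, & x > \tau \\
0, & -\tau \le x \le \tau \\
x + \tau, & x < -\tau
\end{cases}
$$

where x is the input feature, y is the output feature, and τ is the threshold. (This is the standard form of the soft thresholding function.)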
The derivative of the soft thresholding output with respect to the input is:
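$$
\frac{\partial y}{\partial x} =
\begin{cases}
1, & x > \tau \ \text{or}\ x < -\tau \\
0, & -\tau \le x \le \tau
\end{cases}
$$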
From this, we see that the derivative of soft thresholding is either 1 or 0, a property shared with the ReLU activation function. Thus, soft thresholding can also mitigate the risks of gradient vanishing and explosion in deep learning algorithms.
In the soft thresholding function, the threshold must meet two conditions: first, it must be positive; second, it must not exceed the maximum absolute value of the input signal, otherwise all outputs will be zero.
Additionally, the threshold should ideally meet a third condition: each sample should have its own independent threshold based on its noise level.
This is because noise levels often vary across samples. For example, in the same dataset, Sample A may contain less noise while Sample B contains more. In that case, Sample A should use a lower threshold for denoising, while Sample B should use a higher one.
In deep neural networks, although these features and thresholds lose explicit physical meaning, the underlying principle remains the same: each sample should have its own threshold based on its noise level.
The attention mechanism is relatively easy to understand in computer vision. An animal's visual system can quickly scan an entire area, identify a target object, and focus attention on it to extract more details while suppressing irrelevant information. For details, refer to articles on attention mechanisms.
The Squeeze-and-Excitation Network (SENet) is a relatively recent deep learning method built around the attention mechanism. In different samples, different feature channels contribute differently to the classification task. SENet uses a small subnetwork to obtain a set of weights, which are then multiplied with the features of each channel to adjust their magnitudes.
This process can be seen as applying varying levels of attention to different feature channels.
In this approach, each sample has its own unique set of weights. In other words, any two samples will have different weights. In SENet, the weights are obtained via the path: "global pooling → fully connected layer → ReLU → fully connected layer → Sigmoid."
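As a concrete illustration, here is a minimal sketch of an SE-style channel weighting module in PyTorch. The layer sizes and the reduction ratio are illustrative assumptions, not the exact configuration of the original SENet.

```python
import torch
import torch.nn as nn

class SEBlock(nn.Module):
    """Minimal SE-style channel attention:
    global pooling -> fully connected -> ReLU -> fully connected -> Sigmoid."""
    def __init__(self, channels, reduction=16):
        super().__init__()
        self.pool = nn.AdaptiveAvgPool2d(1)          # global average pooling over H x W
        self.fc = nn.Sequential(
            nn.Linear(channels, channels // reduction),
            nn.ReLU(inplace=True),
            nn.Linear(channels // reduction, channels),
            nn.Sigmoid(),                            # one weight in (0, 1) per channel
        )

    def forward(self, x):
        b, c, _, _ = x.shape
        w = self.pool(x).view(b, c)                  # (B, C) channel descriptor
        w = self.fc(w).view(b, c, 1, 1)              # (B, C, 1, 1) per-channel weights
        return x * w                                 # rescale each feature channel
```

Each sample thus receives its own set of channel weights, computed from its own pooled features.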
The deep residual shrinkage network borrows this SENet subnetwork structure to implement soft thresholding under the attention mechanism: an embedded subnetwork learns a set of thresholds, which are then used to apply soft thresholding to each feature channel.
In this subnetwork, the absolute values of the input feature map are first computed, and global average pooling is applied to obtain a feature, denoted A. In a parallel path, the globally pooled feature is fed into a small fully connected network whose last layer uses a Sigmoid function to normalize its output to the range (0, 1), yielding a coefficient α. The final threshold is expressed as α×A.
Thus, the threshold is a number between 0 and 1 multiplied by the average absolute value of the feature map. This ensures the threshold is positive and not excessively large.
Moreover, different samples have different thresholds. Therefore, to some extent, this can be understood as a special attention mechanism: identifying features irrelevant to the current task and setting them to zero via soft thresholding, or identifying relevant features and retaining them.
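As a sketch of how this threshold subnetwork can be combined with soft thresholding inside a residual building block, here is a minimal PyTorch example. It uses channel-wise thresholds of the form α×A (one per channel); the layer sizes, the reduction ratio, and the pre-activation ordering are illustrative assumptions rather than the published architecture.

```python
import torch
import torch.nn as nn

class ShrinkageBlock(nn.Module):
    """Sketch of a residual shrinkage building block:
    the threshold for each channel is alpha * mean(|features|)."""
    def __init__(self, channels, reduction=4):
        super().__init__()
        self.bn1 = nn.BatchNorm2d(channels)
        self.conv1 = nn.Conv2d(channels, channels, 3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(channels)
        self.conv2 = nn.Conv2d(channels, channels, 3, padding=1, bias=False)
        self.relu = nn.ReLU(inplace=True)
        # Subnetwork that estimates the scaling coefficient alpha in (0, 1)
        self.fc = nn.Sequential(
            nn.Linear(channels, channels // reduction),
            nn.ReLU(inplace=True),
            nn.Linear(channels // reduction, channels),
            nn.Sigmoid(),
        )

    @staticmethod
    def soft_threshold(x, tau):
        # y = sign(x) * max(|x| - tau, 0)
        return torch.sign(x) * torch.clamp(x.abs() - tau, min=0.0)

    def forward(self, x):
        out = self.conv1(self.relu(self.bn1(x)))
        out = self.conv2(self.relu(self.bn2(out)))
        b, c, _, _ = out.shape
        a = out.abs().mean(dim=(2, 3))               # A: per-channel average of |features|, shape (B, C)
        alpha = self.fc(a)                           # coefficient alpha in (0, 1), shape (B, C)
        tau = (alpha * a).view(b, c, 1, 1)           # threshold = alpha * A
        out = self.soft_threshold(out, tau)          # shrink small-magnitude features toward zero
        return out + x                               # identity (residual) connection
```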
Finally, by stacking a certain number of basic modules along with convolutional layers, batch normalization, activation functions, global average pooling, and fully connected output layers, the complete deep residual shrinkage network is constructed.
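Continuing the sketch above (and reusing ShrinkageBlock and the same imports), one plausible way to assemble the full network looks like this; the layer counts, channel widths, and input size are illustrative assumptions, not the configuration from the original paper.

```python
class DRSN(nn.Module):
    """Illustrative stacking: conv -> shrinkage blocks -> BN/ReLU -> GAP -> FC."""
    def __init__(self, in_channels=1, channels=16, num_blocks=4, num_classes=10):
        super().__init__()
        self.stem = nn.Conv2d(in_channels, channels, 3, padding=1, bias=False)
        self.blocks = nn.Sequential(*[ShrinkageBlock(channels) for _ in range(num_blocks)])
        self.bn = nn.BatchNorm2d(channels)
        self.relu = nn.ReLU(inplace=True)
        self.pool = nn.AdaptiveAvgPool2d(1)          # global average pooling
        self.fc = nn.Linear(channels, num_classes)   # fully connected output layer

    def forward(self, x):
        out = self.blocks(self.stem(x))
        out = self.pool(self.relu(self.bn(out)))
        return self.fc(out.flatten(1))

# Example usage: logits = DRSN()(torch.randn(8, 1, 32, 32))  # -> shape (8, 10)
```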
The deep residual shrinkage network is essentially a general-purpose feature learning method. This is because, in many feature learning tasks, samples often contain some noise or irrelevant information, which may affect learning performance. For example:
- In image classification, if an image includes many other objects, these can be considered "noise." The deep residual shrinkage network may use the attention mechanism to identify such noise and apply soft thresholding to zero out the corresponding features, potentially improving classification accuracy.
- In speech recognition in noisy environments such as streets or factory floors, the deep residual shrinkage network may likewise improve recognition accuracy, or at least offers a strategy for doing so.