更新時(shí)間:2023-07-21 來源:黑馬程序員 瀏覽量:
ResNet(Residual Network)是由Kaiming He等人提出的深度學(xué)習(xí)神經(jīng)網(wǎng)絡(luò)結(jié)構(gòu),它在2015年的ImageNet圖像識(shí)別競賽中取得了非常顯著的成績,引起了廣泛的關(guān)注。ResNet的主要貢獻(xiàn)是解決了深度神經(jīng)網(wǎng)絡(luò)的梯度消失問題,使得可以訓(xùn)練更深的網(wǎng)絡(luò),從而獲得更好的性能。
問題:在傳統(tǒng)的深度神經(jīng)網(wǎng)絡(luò)中,隨著網(wǎng)絡(luò)層數(shù)的增加,梯度在反向傳播過程中逐漸變小,導(dǎo)致淺層網(wǎng)絡(luò)的權(quán)重更新幾乎沒有效果,難以訓(xùn)練。這被稱為梯度消失問題。
ResNet的解決方法:ResNet引入了“殘差塊”(residual block),每個(gè)殘差塊包含了一條“跳躍連接”(shortcut connection),它允許梯度能夠直接穿過塊,從而避免了梯度消失問題。因此,深度網(wǎng)絡(luò)可以通過恒等映射(identity mapping)來學(xué)習(xí)殘差,使得網(wǎng)絡(luò)在增加深度時(shí)反而變得更容易訓(xùn)練。
ResNet結(jié)構(gòu)特點(diǎn):
1.殘差塊:每個(gè)殘差塊由兩個(gè)或三個(gè)卷積層組成,它們的輸出通過跳躍連接與塊的輸入相加,形成殘差(residual)。
2.跳躍連接:跳躍連接允許梯度直接流過塊,有助于避免梯度消失問題。
3.批量歸一化:ResNet中廣泛使用批量歸一化層來加速訓(xùn)練并穩(wěn)定網(wǎng)絡(luò)。
4.殘差塊堆疊:ResNet通過堆疊多個(gè)殘差塊來構(gòu)建深層網(wǎng)絡(luò)。深度可以根據(jù)任務(wù)的復(fù)雜性而自由選擇。
接下來我們看一個(gè)簡化的ResNet代碼演示(使用TensorFlow):
import tensorflow as tf from tensorflow.keras import layers, models # 定義一個(gè)基本的殘差塊 def residual_block(x, filters, downsample=False): # 如果downsample為True,使用步長為2的卷積層實(shí)現(xiàn)降采樣 stride = 2 if downsample else 1 # 記錄輸入,以便在跳躍連接時(shí)使用 identity = x # 第一個(gè)卷積層 x = layers.Conv2D(filters, kernel_size=3, strides=stride, padding='same')(x) x = layers.BatchNormalization()(x) x = layers.Activation('relu')(x) # 第二個(gè)卷積層 x = layers.Conv2D(filters, kernel_size=3, strides=1, padding='same')(x) x = layers.BatchNormalization()(x) # 如果進(jìn)行了降采樣,需要對(duì)identity進(jìn)行相應(yīng)處理,保證維度一致 if downsample: identity = layers.Conv2D(filters, kernel_size=1, strides=stride, padding='same')(identity) identity = layers.BatchNormalization()(identity) # 跳躍連接:將卷積層的輸出與輸入相加 x = layers.add([x, identity]) x = layers.Activation('relu')(x) return x # 構(gòu)建ResNet網(wǎng)絡(luò) def ResNet(input_shape, num_classes): input_img = layers.Input(shape=input_shape) # 第一個(gè)卷積層 x = layers.Conv2D(64, kernel_size=7, strides=2, padding='same')(input_img) x = layers.BatchNormalization()(x) x = layers.Activation('relu')(x) x = layers.MaxPooling2D(pool_size=3, strides=2, padding='same')(x) # 堆疊殘差塊組成網(wǎng)絡(luò) x = residual_block(x, filters=64) x = residual_block(x, filters=64) x = residual_block(x, filters=64) x = residual_block(x, filters=128, downsample=True) x = residual_block(x, filters=128) x = residual_block(x, filters=128) x = residual_block(x, filters=256, downsample=True) x = residual_block(x, filters=256) x = residual_block(x, filters=256) x = residual_block(x, filters=512, downsample=True) x = residual_block(x, filters=512) x = residual_block(x, filters=512) # 全局平均池化 x = layers.GlobalAveragePooling2D()(x) # 全連接層輸出 x = layers.Dense(num_classes, activation='softmax')(x) # 創(chuàng)建模型 model = models.Model(inputs=input_img, outputs=x) return model # 在這里定義輸入圖像的形狀和類別數(shù) input_shape = (224, 224, 3) num_classes = 1000 # 構(gòu)建ResNet模型 model = ResNet(input_shape, num_classes) model.summary()
請(qǐng)注意,上述代碼是一個(gè)簡化版本的ResNet網(wǎng)絡(luò),實(shí)際上,ResNet有不同的變體,可以根據(jù)任務(wù)的復(fù)雜性和資源的可用性選擇適合的ResNet結(jié)構(gòu)。