Are there any ResNet implementations in TensorFlow? I came across a few (e.g. https://github.com/ry/tensorflow-resnet, https://github.com/xuyuwei/resnet-tf), but they have some bugs (see the Issues section on the respective GitHub pages). I want to train ResNet on ImageNet and am looking for a reliable TensorFlow implementation.
There are some (ResNet-50/101/152) in tensorflow/models under slim. The example notebook shows how to get a pre-trained Inception model running; ResNet is probably no different.
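For reference, restoring one of those pre-trained checkpoints looks roughly like this. This is only a sketch, assuming the TF1-era tf.contrib.slim API and that you have downloaded resnet_v1_50.ckpt from the slim model zoo:

    import tensorflow as tf
    import tensorflow.contrib.slim as slim
    from tensorflow.contrib.slim.nets import resnet_v1

    # Placeholder for a batch of 224x224 RGB images
    images = tf.placeholder(tf.float32, [None, 224, 224, 3])

    # Build ResNet-50 and run it in inference mode
    with slim.arg_scope(resnet_v1.resnet_arg_scope()):
        logits, end_points = resnet_v1.resnet_v1_50(images, num_classes=1000, is_training=False)

    saver = tf.train.Saver()
    with tf.Session() as sess:
        saver.restore(sess, "resnet_v1_50.ckpt")  # checkpoint from the slim model zoo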
I implemented a CIFAR-10 version of ResNet with TensorFlow. The validation errors of ResNet-32, ResNet-56 and ResNet-110 are 6.7%, 6.5% and 6.2% respectively. (You can easily change the number of layers as a hyper-parameter; see the depth rule sketched below.) I tried to keep it friendly for ResNet newcomers and wrote everything in a straightforward way. You can run the cifar10_train.py file directly without any extra downloads.
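For context, the depth of the CIFAR-10 variant follows the 6n+2 rule from the original ResNet paper, so varying a single per-stage block count gives the different depths above; a quick illustration (the function name is just for this example):

    # Depth of a CIFAR-10 ResNet as a function of the per-stage block count n:
    # 3 stages, 2 conv layers per residual block, n blocks per stage,
    # plus the first conv and the final fully connected layer (6n + 2).
    def cifar_resnet_depth(n_blocks_per_stage):
        return 6 * n_blocks_per_stage + 2

    assert cifar_resnet_depth(5) == 32    # ResNet-32
    assert cifar_resnet_depth(9) == 56    # ResNet-56
    assert cifar_resnet_depth(18) == 110  # ResNet-110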
Please find the code below for a custom implementation of ResNet-34. You can use this model to build your own image classification model.
# Dependencies
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import (GlobalAveragePooling2D, Dense, Layer, MaxPooling2D,
                                     Activation, Conv2D, Add, BatchNormalization)

CONFIGURATIONS = {
    "NUM_CLASSES": 3
}
# Custom model class inherited from the Keras Model class
class Resnet34(Model):
    def __init__(self):
        super(Resnet34, self).__init__(name="resnet_34")

        # Stem: 7x7 conv with stride 2, then 3x3 max pooling with stride 2
        self.conv_1 = CustomConv2D(64, 7, 2, padding="same")
        self.max_pool = MaxPooling2D(3, 2)

        # conv2_x: 3 residual blocks with 64 channels
        self.conv_2_1 = ResidualBlock(64)
        self.conv_2_2 = ResidualBlock(64)
        self.conv_2_3 = ResidualBlock(64)

        # conv3_x: 4 residual blocks with 128 channels (stride 2 for downsampling)
        self.conv_3_1 = ResidualBlock(128, 2)
        self.conv_3_2 = ResidualBlock(128)
        self.conv_3_3 = ResidualBlock(128)
        self.conv_3_4 = ResidualBlock(128)

        # conv4_x: 6 residual blocks with 256 channels (stride 2 for downsampling)
        self.conv_4_1 = ResidualBlock(256, 2)
        self.conv_4_2 = ResidualBlock(256)
        self.conv_4_3 = ResidualBlock(256)
        self.conv_4_4 = ResidualBlock(256)
        self.conv_4_5 = ResidualBlock(256)
        self.conv_4_6 = ResidualBlock(256)

        # conv5_x: 3 residual blocks with 512 channels (stride 2 for downsampling)
        self.conv_5_1 = ResidualBlock(512, 2)
        self.conv_5_2 = ResidualBlock(512)
        self.conv_5_3 = ResidualBlock(512)

        # Head: global average pooling and a softmax classifier
        self.global_pool = GlobalAveragePooling2D()
        self.fc_3 = Dense(CONFIGURATIONS["NUM_CLASSES"], activation="softmax")

    def call(self, x, training=True):
        x = self.conv_1(x, training=training)
        x = self.max_pool(x)

        x = self.conv_2_1(x, training=training)
        x = self.conv_2_2(x, training=training)
        x = self.conv_2_3(x, training=training)

        x = self.conv_3_1(x, training=training)
        x = self.conv_3_2(x, training=training)
        x = self.conv_3_3(x, training=training)
        x = self.conv_3_4(x, training=training)

        x = self.conv_4_1(x, training=training)
        x = self.conv_4_2(x, training=training)
        x = self.conv_4_3(x, training=training)
        x = self.conv_4_4(x, training=training)
        x = self.conv_4_5(x, training=training)
        x = self.conv_4_6(x, training=training)

        x = self.conv_5_1(x, training=training)
        x = self.conv_5_2(x, training=training)
        x = self.conv_5_3(x, training=training)

        x = self.global_pool(x)
        return self.fc_3(x)
# Custom Conv2D layer: convolution followed by batch normalization
class CustomConv2D(Layer):
    def __init__(self, n_filters, kernel_size, n_strides, padding="valid"):
        super(CustomConv2D, self).__init__(name="custom_conv2D")

        self.conv = Conv2D(
            filters=n_filters,
            kernel_size=kernel_size,
            activation="relu",
            strides=n_strides,
            padding=padding,
        )
        self.batch_norm = BatchNormalization()

    def call(self, x, training=True):
        x = self.conv(x)
        x = self.batch_norm(x, training=training)
        return x
# Custom residual block inherited from the Keras Layer class
class ResidualBlock(Layer):
    def __init__(self, n_channels, n_strides=1):
        super(ResidualBlock, self).__init__(name="res_block")

        # A "dotted" (projection) shortcut is needed whenever the block downsamples
        self.dotted = (n_strides != 1)

        self.custom_conv_1 = CustomConv2D(n_channels, 3, n_strides, padding="same")
        self.custom_conv_2 = CustomConv2D(n_channels, 3, 1, padding="same")
        self.activation = Activation("relu")
        self.add = Add()

        if self.dotted:
            # 1x1 conv on the shortcut to match spatial size and channel count
            self.custom_conv_3 = CustomConv2D(n_channels, 1, n_strides)

    def call(self, inputs, training=True):
        x = self.custom_conv_1(inputs, training=training)
        x = self.custom_conv_2(x, training=training)

        if self.dotted:
            shortcut = self.custom_conv_3(inputs, training=training)
        else:
            shortcut = inputs

        return self.activation(self.add([x, shortcut]))
# Build the model by running a dummy batch through it, then print the summary
resnet_34 = Resnet34()
resnet_34(tf.zeros([1, 256, 256, 3]), training=False)
resnet_34.summary()
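Once it builds and the summary looks right, the model trains like any other Keras model. A minimal sketch, where the dataset is a random-tensor placeholder you would replace with your own tf.data pipeline:

    # Stand-in dataset: 8 random 256x256 RGB images with integer labels for 3 classes.
    # Replace this with your own tf.data input pipeline.
    images = tf.random.uniform([8, 256, 256, 3])
    labels = tf.random.uniform([8], maxval=CONFIGURATIONS["NUM_CLASSES"], dtype=tf.int32)
    dataset = tf.data.Dataset.from_tensor_slices((images, labels)).batch(4)

    model = Resnet34()
    model.compile(
        optimizer="adam",
        loss="sparse_categorical_crossentropy",  # integer class ids, softmax output
        metrics=["accuracy"],
    )
    model.fit(dataset, epochs=1)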
I implemented ResNet using ronnie.ai and Keras. Both tools are great, though ronnie is easier when starting from scratch.