I am working with code from here (paper here) that creates a GAN. I am trying to apply this to a new domain, switching from their application on MNIST to 3D brain MRI images. My issue is in the defining of the GAN itself.
For example, their code for defining the generative model (takes noise of dimension z_dim and produces an image from the MNIST distribution, so 28x28) is this, with my comments based on how I believe it works:
def generate(self, z):
# start with noise in compact space
assert z.shape[1] == self.z_dim
# Fully connected layer that for some reason expands to latent * 64
output = tflib.ops.linear.Linear('Generator.Input', self.z_dim,
self.latent_dim * 64, z)
output = tf.nn.relu(output)
# Reshape the latent dimension into 4x4 MNIST
output = tf.reshape(output, [-1, self.latent_dim * 4, 4, 4])
# Reduce the latent dimension to get 8x8 MNIST
output = tflib.ops.deconv2d.Deconv2D('Generator.2', self.latent_dim * 4,
self.latent_dim * 2, 5, output)
output = tf.nn.relu(output) # 8 x 8
# To be able to get 28x28 later?
output = output[:, :, :7, :7] # 7 x 7
# Reduce more to get 14x14
output = tflib.ops.deconv2d.Deconv2D('Generator.3', self.latent_dim * 2,
self.latent_dim, 5, output)
output = tf.nn.relu(output) # 14 x 14
output = tflib.ops.deconv2d.Deconv2D('Generator.Output',
self.latent_dim, 1, 5, output)
output = tf.nn.sigmoid(output) # 28 x 28
if self.gen_params is None:
self.gen_params = tflib.params_with_name('Generator')
return tf.reshape(output, [-1, self.x_dim])
And this is my code using niftynet convolutional layers, where z_dim and latent_dim are the same as before at 64, and I've added the results of the print statements:
def generate(self, z):
assert z.shape[1] == self.z_dim
generator_input = FullyConnectedLayer(self.latent_dim * 64,
acti_func='relu',
#with_bn = True,
name='Generator.Input')
output = generator_input(z, is_training=True)
print(output.shape) # (?, 4096)
#output = tflib.ops.linear.Linear('Generator.Input', self.z_dim,
# self.latent_dim * 64, z)
#output = tf.nn.relu(output)
output = tf.reshape(output, [-1, self.latent_dim * 4, 1, 18, 18]) # 4 x 4
print(output.shape) # (?, 256, 1, 18, 18)
generator_2 = DeconvolutionalLayer(self.latent_dim*2,
kernel_size=5,
stride=2,
acti_func='relu',
name='Generator.2')
output = generator_2(output, is_training=True)
#output = tflib.ops.deconv2d.Deconv2D('Generator.2', self.latent_dim * 4,
# self.latent_dim * 2, 5, output)
#output = tf.nn.relu(output) # 8 x 8
print(output.shape) # (?, 512, 2, 36, 128)
#output = output[:, :, :-1, :-1] # 7 x 7
generator_3 = DeconvolutionalLayer(self.latent_dim,
kernel_size=5,
stride=2,
acti_func='relu',
name='Generator.3')
output = generator_3(output, is_training=True)
#output = tflib.ops.deconv2d.Deconv2D('Generator.3', self.latent_dim * 2,
# self.latent_dim, 5, output)
#output = tf.nn.relu(output) # 14 x 14
print(output.shape) # (?, 1024, 4, 72, 64)
generator_out = DeconvolutionalLayer(1,
kernel_size=5,
stride=2,
acti_func='sigmoid',
name='Generator.Output')
output = generator_out(output, is_training=True)
#output = tflib.ops.deconv2d.Deconv2D('Generator.Output',
# self.latent_dim, 1, 5, output)
#output = tf.nn.sigmoid(output) # 28 x 28
if self.gen_params is None:
self.gen_params = tflib.params_with_name('Generator')
print(output.shape) # (?, 2048, 8, 144, 1)
print("Should be %s" % str(self.x_dim)) # [1, 19, 144, 144, 4]
return tf.reshape(output, self.x_dim)
I am not really sure how to be able to get the 19 in there. Currently I get this error.
ValueError: Dimension size must be evenly divisible by 2359296 but is 1575936 for 'Reshape_1' (op: 'Reshape') with input shapes: [?,2048,8,144,1], [5] and with input tensors computed as partial shapes: input1 = [1,19,144,144,4].
I am also relatively new to building NNs and I also have a few questions. What is the point of the latent space when we already have a compact representation in z-space? How do I decide the size of the "output dimension" i.e. the second parameter in the layer constructor?
I have also been looking at a successful implementation of a CNN with here for inspiration. Thank you!
Major edit:
I made some progress and got tensorflow to run the code. However, even with a batch size of 1, I am running into an out of memory error when I try to run the training operation. I calculated one image to have the size 19 * 144 * 144 * 4 * 32 (bits per pixel) = ~50 MB so it is not the data that is causing the memory error. Since I basically just tweaked the GAN parameters until it worked, my issue is probably in there. Below is the whole file.
class MnistWganInv(object):
def __init__(self, x_dim=784, z_dim=64, latent_dim=64, batch_size=80,
c_gp_x=10., lamda=0.1, output_path='./'):
self.x_dim = [-1] + x_dim[1:]
self.z_dim = z_dim
self.latent_dim = latent_dim
self.batch_size = batch_size
self.c_gp_x = c_gp_x
self.lamda = lamda
self.output_path = output_path
self.gen_params = self.dis_params = self.inv_params = None
self.z = tf.placeholder(tf.float32, shape=[None, self.z_dim])
self.x_p = self.generate(self.z)
self.x = tf.placeholder(tf.float32, shape=x_dim)
self.z_p = self.invert(self.x)
self.dis_x = self.discriminate(self.x)
self.dis_x_p = self.discriminate(self.x_p)
self.rec_x = self.generate(self.z_p)
self.rec_z = self.invert(self.x_p)
self.gen_cost = -tf.reduce_mean(self.dis_x_p)
self.inv_cost = tf.reduce_mean(tf.square(self.x - self.rec_x))
self.inv_cost += self.lamda * tf.reduce_mean(tf.square(self.z - self.rec_z))
self.dis_cost = tf.reduce_mean(self.dis_x_p) - tf.reduce_mean(self.dis_x)
alpha = tf.random_uniform(shape=[self.batch_size, 1], minval=0., maxval=1.)
difference = self.x_p - self.x
interpolate = self.x + alpha * difference
gradient = tf.gradients(self.discriminate(interpolate), [interpolate])[0]
slope = tf.sqrt(tf.reduce_sum(tf.square(gradient), axis=1))
gradient_penalty = tf.reduce_mean((slope - 1.) ** 2)
self.dis_cost += self.c_gp_x * gradient_penalty
self.gen_params = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='Generator')
self.inv_params = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='Inverter')
self.dis_params = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='Discriminator')
self.gen_train_op = tf.train.AdamOptimizer(
learning_rate=1e-4, beta1=0.9, beta2=0.999).minimize(
self.gen_cost, var_list=self.gen_params)
self.inv_train_op = tf.train.AdamOptimizer(
learning_rate=1e-4, beta1=0.9, beta2=0.999).minimize(
self.inv_cost, var_list=self.inv_params)
self.dis_train_op = tf.train.AdamOptimizer(
learning_rate=1e-4, beta1=0.9, beta2=0.999).minimize(
self.dis_cost, var_list=self.dis_params)
def generate(self, z):
print(z.shape)
assert z.shape[1] == self.z_dim
with tf.name_scope('Generator.Input') as scope:
generator_input = FullyConnectedLayer(self.latent_dim * 4 * 3 * 18 * 18,
acti_func='relu',
#with_bn = True,
name='Generator.Input')(z, is_training=True)
print(generator_input.shape)
#output = tflib.ops.linear.Linear('Generator.Input', self.z_dim,
# self.latent_dim * 64, z)
#output = tf.nn.relu(output)
generator_input = tf.reshape(generator_input, [-1, 3, 18, 18, self.latent_dim * 4]) # 4 x 4
print(generator_input.shape)
with tf.name_scope('Generator.2') as scope:
generator_2 = DeconvolutionalLayer(self.latent_dim*2,
kernel_size=5,
stride=2,
acti_func='relu',
name='Generator.2')(generator_input, is_training=True)
#output = tflib.ops.deconv2d.Deconv2D('Generator.2', self.latent_dim * 4,
# self.latent_dim * 2, 5, output)
#output = tf.nn.relu(output) # 8 x 8
print(generator_2.shape)
with tf.name_scope('Generator.3') as scope:
generator_3 = DeconvolutionalLayer(self.latent_dim,
kernel_size=5,
stride=2,
acti_func='relu',
name='Generator.3')(generator_2, is_training=True)
#output = tflib.ops.deconv2d.Deconv2D('Generator.3', self.latent_dim * 2,
# self.latent_dim, 5, output)
#output = tf.nn.relu(output) # 14 x 14
print(generator_3.shape)
with tf.name_scope('Generator.Output') as scope:
generator_out = DeconvolutionalLayer(4,
kernel_size=5,
stride=2,
acti_func='sigmoid',
name='Generator.Output')(generator_3, is_training=True)
#output = tflib.ops.deconv2d.Deconv2D('Generator.Output',
# self.latent_dim, 1, 5, output)
#output = tf.nn.sigmoid(output) # 28 x 28
if self.gen_params is None:
self.gen_params = tflib.params_with_name('Generator')
print(generator_out.shape)
generator_out = generator_out[:, :19, :, :, :]
print(generator_out.shape)
print("Should be %s" % str(self.x_dim))
return tf.reshape(generator_out, self.x_dim)
def discriminate(self, x):
input = tf.reshape(x, self.x_dim) # 28 x 28
with tf.name_scope('Discriminator.Input') as scope:
discriminator_input = ConvolutionalLayer(self.latent_dim,
kernel_size=5,
stride=2,
acti_func='leakyrelu',
name='Discriminator.Input')(input, is_training=True)
#output = tflib.ops.conv2d.Conv2D(
# 'Discriminator.Input', 1, self.latent_dim, 5, output, stride=2)
#output = tf.nn.leaky_relu(output) # 14 x 14
with tf.name_scope('Discriminator.2') as scope:
discriminator_2 = ConvolutionalLayer(self.latent_dim*2,
kernel_size=5,
stride=2,
acti_func='leakyrelu',
name='Discriminator.2')(discriminator_input, is_training=True)
#output = tflib.ops.conv2d.Conv2D(
# 'Discriminator.2', self.latent_dim, self.latent_dim * 2, 5,
# output, stride=2)
#output = tf.nn.leaky_relu(output) # 7 x 7
with tf.name_scope('Discriminator.3') as scope:
discriminator_3 = ConvolutionalLayer(self.latent_dim*4,
kernel_size=5,
stride=2,
acti_func='leakyrelu',
name='Discriminator.3')(discriminator_2, is_training=True)
#output = tflib.ops.conv2d.Conv2D(
# 'Discriminator.3', self.latent_dim * 2, self.latent_dim * 4, 5,
# output, stride=2)
#output = tf.nn.leaky_relu(output) # 4 x 4
discriminator_3 = tf.reshape(discriminator_3, [-1, self.latent_dim * 48])
with tf.name_scope('Discriminator.Output') as scope:
discriminator_out = FullyConnectedLayer(1,
name='Discriminator.Output')(discriminator_3, is_training=True)
#output = tflib.ops.linear.Linear(
# 'Discriminator.Output', self.latent_dim * 64, 1, output)
discriminator_out = tf.reshape(discriminator_out, [-1])
if self.dis_params is None:
self.dis_params = tflib.params_with_name('Discriminator')
return discriminator_out
def invert(self, x):
output = tf.reshape(x, self.x_dim) # 28 x 28
with tf.name_scope('Inverter.Input') as scope:
inverter_input = ConvolutionalLayer(self.latent_dim,
kernel_size=5,
stride=2,
#padding='VALID',
#w_initializer=self.initializers['w'],
#w_regularizer=self.regularizers['w'],
#b_initializer=self.initializers['b'],
#b_regularizer=self.regularizers['b'],
acti_func='leakyrelu',
#with_bn = True,
name='Inverter.Input')
#output = tflib.ops.conv2d.Conv2D(
# 'Inverter.Input', 1, self.latent_dim, 5, output, stride=2)
#output = tf.nn.leaky_relu(output) # 14 x 14
output = inverter_input(output, is_training=True)
with tf.name_scope('Inverter.2') as scope:
inverter_2 = ConvolutionalLayer(self.latent_dim*2,
kernel_size=5,
stride=2,
acti_func='leakyrelu',
name='Inverter.2')
output = inverter_2(output, is_training=True)
#output = tflib.ops.conv2d.Conv2D(
# 'Inverter.2', self.latent_dim, self.latent_dim * 2, 5, output,
# stride=2)
#output = tf.nn.leaky_relu(output) # 7 x 7
with tf.name_scope('Inverter.3') as scope:
inverter_3 = ConvolutionalLayer(self.latent_dim*4,
kernel_size=5,
stride=2,
acti_func='leakyrelu',
name='Inverter.3')
output = inverter_3(output, is_training=True)
#output = tflib.ops.conv2d.Conv2D(
# 'Inverter.3', self.latent_dim * 2, self.latent_dim * 4, 5,
# output, stride=2)
#output = tf.nn.leaky_relu(output) # 4 x 4
output = tf.reshape(output, [-1, self.latent_dim * 48])
with tf.name_scope('Inverter.4') as scope:
inverter_4 = FullyConnectedLayer(self.latent_dim*8,
acti_func='leakyrelu',
#with_bn = True,
name='Inverter.4')
output = inverter_4(output, is_training=True)
#output = tflib.ops.linear.Linear(
# 'Inverter.4', self.latent_dim * 64, self.latent_dim * 8, output)
#output = tf.nn.leaky_relu(output)
with tf.name_scope('Inverter.Output') as scope:
inverter_output = FullyConnectedLayer(self.z_dim,
acti_func='leakyrelu',
#with_bn = True,
name='Inverter.Output')
output = inverter_output(output, is_training=True)
#output = tflib.ops.linear.Linear(
# 'Inverter.Output', self.latent_dim * 8, self.z_dim, output)
output = tf.reshape(output, [-1, self.z_dim])
if self.inv_params is None:
self.inv_params = tflib.params_with_name('Inverter')
return output
def train_gen(self, sess, x, z):
_gen_cost, _ = sess.run([self.gen_cost, self.gen_train_op],
feed_dict={self.x: x, self.z: z})
return _gen_cost
def train_dis(self, sess, x, z):
_dis_cost, _ = sess.run([self.dis_cost, self.dis_train_op],
feed_dict={self.x: x, self.z: z})
return _dis_cost
def train_inv(self, sess, x, z):
_inv_cost, _ = sess.run([self.inv_cost, self.inv_train_op],
feed_dict={self.x: x, self.z: z})
return _inv_cost
def generate_from_noise(self, sess, noise, frame):
samples = sess.run(self.x_p, feed_dict={self.z: noise})
for i in range(batch_size):
save_array_as_nifty_volume(samples[i], "examples/img_{0:}.nii.gz".format(n*batch_size + i))
#tflib.save_images.save_images(
# samples.reshape((-1, 28, 28)),
# os.path.join(self.output_path, 'examples/samples_{}.png'.format(frame)))
return samples
def reconstruct_images(self, sess, images, frame):
reconstructions = sess.run(self.rec_x, feed_dict={self.x: images})
comparison = np.zeros((images.shape[0] * 2, images.shape[1]),
dtype=np.float32)
for i in range(images.shape[0]):
comparison[2 * i] = images[i]
comparison[2 * i + 1] = reconstructions[i]
for i in range(batch_size):
save_array_as_nifty_volume(comparison[i], "examples/img_{0:}.nii.gz".format(n*batch_size + i))
#tflib.save_images.save_images(
# comparison.reshape((-1, 28, 28)),
# os.path.join(self.output_path, 'examples/recs_{}.png'.format(frame)))
return comparison
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--z_dim', type=int, default=64, help='dimension of z')
parser.add_argument('--latent_dim', type=int, default=64,
help='latent dimension')
parser.add_argument('--iterations', type=int, default=100000,
help='training steps')
parser.add_argument('--dis_iter', type=int, default=5,
help='discriminator steps')
parser.add_argument('--c_gp_x', type=float, default=10.,
help='coefficient for gradient penalty x')
parser.add_argument('--lamda', type=float, default=.1,
help='coefficient for divergence of z')
parser.add_argument('--output_path', type=str, default='./',
help='output path')
parser.add_argument('-config')
args = parser.parse_args()
config = parse_config(args.config)
config_data = config['data']
print("Loading data...")
# dataset iterator
dataloader = DataLoader(config_data)
dataloader.load_data()
batch_size = config_data['batch_size']
full_data_shape = [batch_size] + config_data['data_shape']
#train_gen, dev_gen, test_gen = tflib.mnist.load(args.batch_size, args.batch_size)
def inf_train_gen():
while True:
train_pair = dataloader.get_subimage_batch()
tempx = train_pair['images']
tempw = train_pair['weights']
tempy = train_pair['labels']
yield tempx, tempw, tempy
#_, _, test_data = tflib.mnist.load_data()
#fixed_images = test_data[0][:32]
#del test_data
tf.set_random_seed(326)
np.random.seed(326)
fixed_noise = np.random.randn(64, args.z_dim)
print("Initializing GAN...")
mnistWganInv = MnistWganInv(
x_dim=full_data_shape, z_dim=args.z_dim, latent_dim=args.latent_dim,
batch_size=batch_size, c_gp_x=args.c_gp_x, lamda=args.lamda,
output_path=args.output_path)
saver = tf.train.Saver(max_to_keep=1000)
with tf.Session() as session:
session.run(tf.global_variables_initializer())
images = noise = gen_cost = dis_cost = inv_cost = None
dis_cost_lst, inv_cost_lst = [], []
print("Starting training...")
for iteration in range(args.iterations):
for i in range(args.dis_iter):
noise = np.random.randn(batch_size, args.z_dim)
images, images_w, images_y = next(inf_train_gen())
dis_cost_lst += [mnistWganInv.train_dis(session, images, noise)]
inv_cost_lst += [mnistWganInv.train_inv(session, images, noise)]
gen_cost = mnistWganInv.train_gen(session, images, noise)
dis_cost = np.mean(dis_cost_lst)
inv_cost = np.mean(inv_cost_lst)
tflib.plot.plot('train gen cost', gen_cost)
tflib.plot.plot('train dis cost', dis_cost)
tflib.plot.plot('train inv cost', inv_cost)
if iteration % 100 == 99:
mnistWganInv.generate_from_noise(session, fixed_noise, iteration)
mnistWganInv.reconstruct_images(session, fixed_images, iteration)
if iteration % 1000 == 999:
save_path = saver.save(session, os.path.join(
args.output_path, 'models/model'), global_step=iteration)
if iteration % 1000 == 999:
dev_dis_cost_lst, dev_inv_cost_lst = [], []
for dev_images, _ in dev_gen():
noise = np.random.randn(batch_size, args.z_dim)
dev_dis_cost, dev_inv_cost = session.run(
[mnistWganInv.dis_cost, mnistWganInv.inv_cost],
feed_dict={mnistWganInv.x: dev_images,
mnistWganInv.z: noise})
dev_dis_cost_lst += [dev_dis_cost]
dev_inv_cost_lst += [dev_inv_cost]
tflib.plot.plot('dev dis cost', np.mean(dev_dis_cost_lst))
tflib.plot.plot('dev inv cost', np.mean(dev_inv_cost_lst))
if iteration < 5 or iteration % 100 == 99:
tflib.plot.flush(os.path.join(args.output_path, 'models'))
tflib.plot.tick()