一、熟悉样本:了解Fashion-MNIST数据集

FashionMNIST数据集的单个样本为28pixel*28pixel的灰度图片。训练集有60000张图片,测试集有10000张图片。样本内容为上衣、裤子、鞋子等服饰,一共分为10类。

二、下载Fashion-MNIST数据集

三、代码实现:读取及显示Fashion-MNIST数据集中的数据

from tensorflow.examples.tutorials.mnist import input_datamnist = input_data.read_data_sets("./fashion/", one_hot=False)
print("输入数据:", mnist.train.images)
print("输入数据的形状:", mnist.train.images.shape)
print("输入数据的标签:", mnist.train.labels)import pylab
im = mnist.train.images[1]
im = im.reshape(-1, 28)
pylab.imshow(im)
pylab.show()

1.在tf.keras接口中读取Fashion_MNIST数据集

import tensorflow as tf
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.fashion_mnist.load_data()

四、代码实现:定义胶囊网络模型CapsuleNetModel

import tensorflow as tf
import tensorflow.contrib.slim as slim
import numpy as npclass CapsuleNetModel:def __init__(self, batch_size, n_classes, iter_routing):self.batch_size = batch_sizeself.n_classes = n_classesself.iter_routing = iter_routing

三、代码实现:实现胶囊网络的基本结构

 def CapsuleNet(self, img):#定义网络模型结构with tf.variable_scope('Conv1_layer') as scope:#定义第一个正常卷积层 ReLU Conv1output = slim.conv2d(img, num_outputs=256, kernel_size=[9, 9], stride=1, padding='VALID', scope=scope)assert output.get_shape() == [self.batch_size, 20, 20, 256]with tf.variable_scope('PrimaryCaps_layer') as scope:#定义主胶囊网络 Primary Capsoutput = slim.conv2d(output, num_outputs=32*8, kernel_size=[9, 9], stride=2, padding='VALID', scope=scope, activation_fn=None)output = tf.reshape(output, [self.batch_size, -1, 1, 8])  #将结果变成32*6*6个胶囊单元,每个单元为8维向量assert output.get_shape() == [self.batch_size, 1152, 1, 8]with tf.variable_scope('DigitCaps_layer') as scope:#定义数字胶囊 Digit Capsu_hats = []input_groups = tf.split(axis=1, num_or_size_splits=1152, value=output)#将输入按照胶囊单元分开for i in range(1152): #遍历每个胶囊单元#利用卷积核为[1,1]的卷积操作,让u与w相乘,再相加得到u_hatone_u_hat = slim.conv2d(input_groups[i], num_outputs=16*10, kernel_size=[1, 1], stride=1, padding='VALID', scope='DigitCaps_layer_w_'+str(i), activation_fn=None)one_u_hat = tf.reshape(one_u_hat, [self.batch_size, 1, 10, 16])#每个胶囊单元变成了16维向量u_hats.append(one_u_hat)u_hat = tf.concat(u_hats, axis=1)#将所有的胶囊单元中的one_u_hat合并起来assert u_hat.get_shape() == [self.batch_size, 1152, 10, 16]#初始化b值b_ijs = tf.constant(np.zeros([1152, 10], dtype=np.float32))v_js = []for r_iter in range(self.iter_routing):#按照指定循环次数,计算动态路由with tf.variable_scope('iter_'+str(r_iter)):c_ijs = tf.nn.softmax(b_ijs, axis=1)  #根据b值,获得耦合系数#将下列变量按照10类分割,每一类单独运算c_ij_groups = tf.split(axis=1, num_or_size_splits=10, value=c_ijs)b_ij_groups = tf.split(axis=1, num_or_size_splits=10, value=b_ijs)u_hat_groups = tf.split(axis=2, num_or_size_splits=10, value=u_hat)for i in range(10):#生成具有跟输入一样尺寸的卷积核[1152, 1],输入为16通道,卷积核个数为1个c_ij = tf.reshape(tf.tile(c_ij_groups[i], [1, 16]), [1152, 1, 16, 1])#利用深度卷积实现u_hat与c矩阵的对应位置相乘,输出的通道数为16*1个s_j = tf.nn.depthwise_conv2d(u_hat_groups[i], c_ij, strides=[1, 1, 1, 1], padding='VALID')assert s_j.get_shape() == [self.batch_size, 1, 1, 16]s_j = tf.reshape(s_j, [self.batch_size, 16])v_j = self.squash(s_j)  #使用squash激活函数,生成最终的输出vjassert v_j.get_shape() == [self.batch_size, 16]#根据vj来计算,并更新b值b_ij_groups[i] = b_ij_groups[i]+tf.reduce_sum(tf.matmul(tf.reshape(u_hat_groups[i], [self.batch_size, 1152, 16]), tf.reshape(v_j, [self.batch_size, 16, 1])), axis=0)if r_iter == self.iter_routing-1:  #迭代结束后,再生成一次vj,得到数字胶囊真正的输出结果v_js.append(tf.reshape(v_j, [self.batch_size, 1, 16]))b_ijs = tf.concat(b_ij_groups, axis=1)#将10类的b合并到一起output = tf.concat(v_js, axis=1)#将10类的vj合并到一起,生成的形状为[self.batch_size, 10, 16]的结果return  outputdef squash(self, s_j):  #定义激活函数s_j_norm_square = tf.reduce_mean(tf.square(s_j), axis=1, keepdims=True)v_j = s_j_norm_square*s_j/((1+s_j_norm_square)*tf.sqrt(s_j_norm_square+1e-9))return v_j

六、代码实现:构建胶囊网络模型

在CapsuleNetModel类中,定义build_model方法来构建胶囊网络模型。具体实现步骤如下。

  1. 将张量图重置
  2. 用CapsuleNet方法构建网络节点
  3. 对CapsuleNet方法返回的结果进行范数计算,得到分类结果self.v_len
  4. 在训练模式下,添加解码器网络,重建输入图片
  5. 实现loss方法,将边距损失与重建损失放在一起,生成总的损失值。
  6. 将损失值放到优化器中,生成张量操作符train_op,用于训练
 def build_model(self, is_train=False,learning_rate = 1e-3):tf.reset_default_graph()#定义占位符self.y = tf.placeholder(tf.float32, [self.batch_size, self.n_classes])self.x = tf.placeholder(tf.float32, [self.batch_size, 28, 28, 1], name='input')#定义计步器self.global_step = tf.Variable(0, name='global_step', trainable=False)initializer = tf.truncated_normal_initializer(mean=0.0, stddev=0.01)biasInitializer = tf.constant_initializer(0.0)with slim.arg_scope([slim.conv2d], trainable=is_train, weights_initializer=initializer, biases_initializer=biasInitializer):self.v_jsoutput = self.CapsuleNet(self.x) #构建胶囊网络tf.check_numerics(self.v_jsoutput,"self.v_jsoutput is nan ")#判断张量是否为nan 或infwith tf.variable_scope('Masking'):  self.v_len = tf.norm(self.v_jsoutput, axis=2)#计算输出值的欧几里得范数[self.batch_size, 10]if is_train:            #如果是训练模式,重建输入图片masked_v = tf.matmul(self.v_jsoutput, tf.reshape(self.y, [-1, 10, 1]), transpose_a=True)masked_v = tf.reshape(masked_v, [-1, 16])with tf.variable_scope('Decoder'):output = slim.fully_connected(masked_v, 512, trainable=is_train)output = slim.fully_connected(output, 1024, trainable=is_train)self.output = slim.fully_connected(output, 784, trainable=is_train, activation_fn=tf.sigmoid)self.total_loss = self.loss(self.v_len,self.output)#计算loss值#使用退化学习率learning_rate_decay = tf.train.exponential_decay(learning_rate, global_step=self.global_step, decay_steps=2000,decay_rate=0.9)#定义优化器self.train_op = tf.train.AdamOptimizer(learning_rate_decay).minimize(self.total_loss, global_step=self.global_step)#定义保存及恢复模型关键点所使用的saverself.saver = tf.train.Saver(tf.global_variables(), max_to_keep=1)def loss(self,v_len, output): #定义loss计算函数max_l = tf.square(tf.maximum(0., 0.9-v_len))max_r = tf.square(tf.maximum(0., v_len - 0.1))l_c = self.y*max_l+0.5 * (1 - self.y) * max_rmargin_loss = tf.reduce_mean(tf.reduce_sum(l_c, axis=1))origin = tf.reshape(self.x, shape=[self.batch_size, -1])reconstruction_err = tf.reduce_mean(tf.square(output-origin))total_loss = margin_loss+0.0005*reconstruction_err#将边距损失与重建损失一起构成lossreturn total_loss

七、代码实现:载入数据集,并训练胶囊网络模型

import tensorflow as tf
import time
import of
import numpy as np
import imageioCapsulemodel = __import__("8-2 Capsulemodel)
CapsuleNetModel = Capsulemodel.CapsuleNetModel#载入数据集
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("./fashion/", one_hot=True)def save_images(imgs, size, path):      #定义函数,保存图片imgs = (img + 1.) / 2return (imageio.imwirte(path, mergeImags(imgs, size)))def mergeImgs(images, size):    #定义函数, 合并图片h, w = images.shape[1], images.shape[2]imgs = np.zeros((h * size[0], w * size[1], 3))for idx, image in enumerate(images):i = idx % size[1]j = idx // size[1]imgs[j * h:j * h + h, i * w:i * w + w, :] = imageimgs[j * h:j * h + h, i * w:i * w + w, :] = imagereturn imgsbatch_size = 128
learning_rate = 1e-3
training_epochs = 5
n_chass = 10
iter_routing =3

八、代码实现:建立会话训练模型

建立会话训练模型是咋main函数中完成操作,具体步骤如下:

  1. 实例化胶囊网络模型类CapsuleNetModel
  2. 建立会话
  3. 在会话中,用循环进行迭代训练
def main(self):#实例化模型:capsmodel = CapsuleNetModel(batch_size, n_class, iter_routing)capsmodel.build_model(is_train=True,learning_rate=learning_rate)#构建网络节点os.makedirs('results', exist_ok=True)#创建路径os.makedirs('./model', exist_ok=True)with tf.Session() as sess:  #建立会话sess.run(tf.global_variables_initializer())#载入检查点checkpoint_path = tf.train.latest_checkpoint('./model/')print("checkpoint_path",checkpoint_path)if checkpoint_path !=None:capsmodel.saver.restore(sess, checkpoint_path)history = []for epoch in range(training_epochs):#按照指定次数迭代数据集total_batch = int(mnist.train.num_examples/batch_size)lossvalue= 0for i in range(total_batch):  #遍历数据集batch_x, batch_y = mnist.train.next_batch(batch_size)#取出数据batch_x = np.reshape(batch_x,[batch_size, 28, 28, 1])batch_y = np.asarray(batch_y,dtype=np.float32)tic = time.time()  #计算运行时间_, loss_value = sess.run([capsmodel.train_op, capsmodel.total_loss], feed_dict={capsmodel.x: batch_x,  capsmodel.y: batch_y})lossvalue +=loss_valueif i % 20 == 0:#每训练20次,输出一次结果print(str(i)+'用时:'+str(time.time()-tic)+' loss:',loss_value)cls_result, recon_imgs = sess.run( [capsmodel.v_len, capsmodel.output], feed_dict={capsmodel.x: batch_x,  capsmodel.y: batch_y})imgs = np.reshape(recon_imgs, (batch_size, 28, 28, 1))size = 6save_images(imgs[0:size * size, :], [size, size], 'results/test_%03d.png' % i)#将结果保存为图片#获得分类结果,评估准确率argmax_idx = np.argmax(cls_result,axis= 1)batch_y_idx = np.argmax(batch_y,axis= 1)print(argmax_idx[:3],batch_y_idx[:3])cls_acc = np.mean(np.equal(argmax_idx, batch_y_idx).astype(np.float32))print('正确率 : ' + str(cls_acc * 100)) history.append(lossvalue/total_batch)if lossvalue/total_batch == min(history):ckpt_path = os.path.join('./model', 'model.ckpt')capsmodel.saver.save(sess, ckpt_path, global_step=capsmodel.global_step.eval())#保存检查点print("save model",ckpt_path)print(epoch,lossvalue/total_batch) if __name__ == "__main__":tf.app.run()

实例39:用胶囊网络识别黑白图中的服装图片相关推荐

  1. 【Pytorch神经网络实战案例】08 识别黑白图中的服装图案(Fashion-MNIST)

    1 Fashion-MNIST简介 FashionMNIST 是一个替代 MNIST 手写数字集 的图像数据集. 它是由 Zalando(一家德国的时尚科技公司)旗下的研究部门提供.其涵盖了来自 10 ...

  2. 结合胶囊网络Capsule和图卷积GCN的文章

    结合胶囊网络Capsule和图卷积GCN的文章 一.Capsule Neural Networks for Graph Classification 1.1 文章概要 1.2 实现方法 1.2.1 G ...

  3. Python识别璇玑图中诗的数量

    Python识别璇玑图中诗的数量 一.璇玑图简介 璇玑图的读法有很多,这里我使用七七棋盘格的读法,在璇玑图中分离出一个七七棋盘格,如下表 吏 官 同 流 污 合 玩 痞 悍 蒙 骗 造 假 蛋 鸡 宴 ...

  4. python实现胶囊网络_在TensorFlow中实现胶囊网络

    点击上方"小白学视觉",选择加"星标"或"置顶" 重磅干货,第一时间送达 我们都知道,在许多计算机视觉任务中,卷积神经网络(CNN)的性能均 ...

  5. php户型图识别,买房必看!一分钟学会如何识别户型图中隐藏的猫腻

    乐山房产/讯  户型图是房屋的平面空间布局图,简单的讲,就是用平面的方式把房屋的结构.格局以及尺寸画下来,使人更直观的了解房屋的走向布局.买房子有句常用语是"行不行,看户型",可想 ...

  6. origin作图:如何将一列数据及其标准差放在图中,使图片更加美观

    最近在做图,发现origin软件还挺好用的,但是想把图片做的好看非常不容易,所以想把作图过程中遇到的一些有意思的技巧在这里记录一下(第一次做,图可能并不好看). @一.将数据导入origin表格中,并 ...

  7. Siamese Capsule Networks 翻译 (孪生胶囊网络)

    摘要 胶囊网络在事实上的基准计算机视觉数据集(例如MNIST,CIFAR和smallNORB)上显示出令人鼓舞的结果. 虽然,它们尚未在以下任务上进行测试:(1)所检测到的实体固有地具有更复杂的内部表 ...

  8. 有关胶囊网络你所应知道的一切

    作者 | 一轩明月 编辑 | NewBeeNLP 在使用卷积神经网络(CNNs)解决计算机视觉任务的时候,视角的改变(角度.位置.剪应力等等)很大程度上会造成网络表现的剧烈波动,从而限制了模型的泛化能 ...

  9. 胶囊网络为何如此热门?与卷积神经网络相比谁能更胜一筹?

    编译 | AI科技大本营 参与 |  孙士洁 编辑 |  明 明 [AI科技大本营按]胶囊网络是什么?胶囊网络怎么能克服卷积神经网络的缺点和不足?机器学习顾问AurélienGéron发表了自己的看法 ...

最新文章

  1. 第九次作业-测试报告和用户使用手册
  2. 一对多和多对一的关系,用mybatis写
  3. 华南赛区线上比赛安排
  4. vs界面竖线光标变成灰色方块,输入时替代已有字符
  5. 备忘录(scanf和continue)
  6. P3804 【模板】后缀自动机
  7. createmutex创建的锁需要手动关闭句柄吗_你知道吗?汽车的儿童锁居然还能发挥这么大的作用!...
  8. NSZombieEnabled使用
  9. JavaScript基础二
  10. java ognl使用_java框架篇---struts之OGNL详解
  11. iOS开发系列--Objective-C之协议、代码块、分类
  12. 新浪视频播放器站外调用代码
  13. 史上最详细的UE4安装教程(没有之一,就是史上最详细,不服气你来打我呀)
  14. winpe进入linux系统,制作U盘Linux 与WinPE启动
  15. 程序员装b指南(转)
  16. 人员离职it检查_公司软件开发人员离职信_检讨书
  17. Android Status Bar
  18. 进化论VS中性突变理论
  19. Linux下memc-nginx-module模块指令说明+memcached支持的命令
  20. ai自动配音_自媒体免费配音神器,一键生成100条AI配音

热门文章

  1. 阜阳太和中学2021年高考成绩查询,2021年高考落下帷幕 十二年寒窗 一笑而“过”...
  2. 凤阳中学2021高考成绩查询,2021“我为天职师大代言”|走进安徽凤阳中学
  3. springboot毕设项目酷玩平台设计43qgi(java+VUE+Mybatis+Maven+Mysql)
  4. Dataset:Medical Data and Hospital Readmissions医疗数据和医院再入院情况数据集的简介、下载、使用方法之详细攻略
  5. 使用GDIView工具排查GDI对象泄漏导致程序UI界面绘制异常的问题
  6. 算法-差分法-c++
  7. ps拉长psd按钮图层
  8. 人工智能深度学习笔记
  9. 数据库原理之多值依赖
  10. adb命令查看android系统用户userid