爬虫部分

用爬虫从百度获取关键词的图片,进行简单的人工删选后作为训练用的原始数据。爬虫代码可从网上方便找到。

# -*- coding:utf-8 -*-import reimport requestsimport osword = 'something'download_dir = './pic/'#创建文件夹def mdir():if not os.path.exists(download_dir):os.mkdir(download_dir)if not os.path.exists(download_dir+word):os.mkdir(download_dir+word)os.chdir(download_dir+word)   #获取图片的地址def geturl(page_number):url = "http://image.baidu.com/search/avatarjson?tn=resultjsonavatarnew&ie=utf-8&word=" + word + "&cg=girl&rn=60&pn=" +str(page_number)html = requests.get(url).textpic_url = re.findall('"objURL":"(.*?)",',html,re.S)return pic_url#下载图片def downimage(pic_url, page_number):i = 1for each in pic_url:print(each)try:pic= requests.get(each, timeout=5)except requests.exceptions.ConnectionError:print('========Error!CAN NOT DOWMLOAD THIS PICTURE!!!========')i -= 1continueexcept requests.exceptions.Timeout:print("========REQUESTTIMEOUT !!!========")i -= 1string = word+ str(page_number) + '-' + str(i) + '.jpg'fp = open(string,'wb')fp.write(pic.content)fp.close()i += 1   if __name__ == '__main__':mdir()page_number = 10page_count = 60while True:mPicUrl = geturl(page_count)downimage(mPicUrl, page_number)page_count += 60page_number += 1

数据处理部分

将下载好的图片进行批处理,先把图片大小转化为神经网络输入的尺寸(这里是(64,64,3)),再生成pkl文件(当然也可以用其他格式,如tfrecord)。

# -*- coding:utf-8 -*-import PIL.Image as Imageimport numpy as npimport randomimport pickleimport os#改变图像大小为(64,64)new_x, new_y = 64, 64def resize(folders_path):folders = os.listdir(folders_path)for folder in folders:files = os.listdir(folders_path +folder)i=0for file in files:#读取图片,PIL.Image库没有close方法,会导致运行出错,因此用内置的openwith open(folders_path+ folder + '/' + file ,'rb') as f:img = Image.open(f).convert('RGB')try:resized =img.resize((new_x, new_y), resample=Image.LANCZOS)resized.save(folders_path +folder + '/resize-' + str(i)+'.jpg', format="jpeg")except:print('Get ERROR in '+folders_path +folder + '/' + file)#删除原图os.remove(folders_path + folder + '/' + file)i += 1#函数调用:生成数据集def initPKL(imgSet_shuffle, train_or_test):imgSet = []labels = []label_names = []if train_or_test == 'train':set_name = 'trainSet.pkl'else:set_name = 'testSet.pkl'for i in imgSet_shuffle:imgSet.append(i[0])labels.append(i[1])label_names.append(i[2])imgSet = np.array(imgSet)labels = np.array(labels)label_names = np.array(label_names)arr = (imgSet,labels,label_names)#写入文件data = (arr[0],arr[1],arr[2])output = open(set_name, 'wb')pickle.dump(data, output)output.close()def initArr(folders_path):i = 0imgSet = []folders = os.listdir(folders_path)for folder in folders:#类别个数,几个0代表几类label = [0,0,0]files = os.listdir(folders_path +folder)label[i] = 1for file in files:#读取图片img_arr = np.array(Image.open(folders_path+ folder + '/' + file)) / 255imgSet.append((img_arr, label,folder))i += 1return imgSet#将图片转换成数组train_folders_path= './train/'test_folders_path= './test/'resize(train_folders_path)resize(test_folders_path)train_imgSet =initArr(train_folders_path)test_imgSet =initArr(test_folders_path)#打乱顺序random.shuffle(train_imgSet)random.shuffle(test_imgSet)train_set_shuffle= np.array(train_imgSet)test_set_shuffle =np.array(test_imgSet)# 分别生成训练集和测试集initPKL(train_set_shuffle,'train')initPKL(test_set_shuffle,'test')#测试生成的数据集f = open('./trainSet.pkl', 'rb')x, y, z =pickle.load(f)f.close()print(np.shape(x[3]), y[3], z[3])

深度学习部分

这里用比较简单的AlexNet作为例子,构建了一个小型的神经网络,要注意输入和输出的大小((64,64,3)和3)。除了训练,代码还增加了断点保存,模型加载,预判预测等功能。

# -*- coding:utf-8 -*-import numpy as npimport tflearnfrom tflearn.data_utils import shufflefrom tflearn.layers.core import input_data, dropout, fully_connectedfrom tflearn.layers.normalization importlocal_response_normalizationfrom tflearn.layers.conv import conv_2d, max_pool_2dfrom tflearn.layers.estimatorimport regressionfrom tflearn.data_preprocessing importImagePreprocessingfrom tflearn.data_augmentation import ImageAugmentationimport timeimport pickle_loadModel = False_trainModel = True#False# 加载数据集X, Y, Y_name=pickle.load(open("trainSet.pkl", "rb"))X_test, Y_test,Y_test_name =pickle.load(open("testSet.pkl", "rb"))# 打乱数据X, Y, Y_name=shuffle(X, Y, Y_name)#数据处理img_prep =ImagePreprocessing()img_prep.add_featurewise_zero_center()img_prep.add_featurewise_stdnorm()# 翻转、旋转和模糊效果数据集中的图片,来创造一些合成训练数据.img_aug =ImageAugmentation()img_aug.add_random_flip_leftright()img_aug.add_random_flip_updown()img_aug.add_random_rotation(max_angle=25.)img_aug.add_random_blur(sigma_max=3.)# Building'AlexNet',注意输入维数network =input_data(shape=[None, 64, 64, 3])network =conv_2d(network, 96, 11, strides=4, activation='relu')network =max_pool_2d(network, 3, strides=2)network =local_response_normalization(network)network =conv_2d(network, 256, 5, activation='relu')network =max_pool_2d(network, 3, strides=2)network =local_response_normalization(network)network =conv_2d(network, 384, 3, activation='relu')network =conv_2d(network, 384, 3, activation='relu')network =conv_2d(network, 256, 3, activation='relu')network =max_pool_2d(network, 3, strides=2)network =local_response_normalization(network)network =fully_connected(network, 4096, restore=True, activation='tanh')network =dropout(network, 0.5)network =fully_connected(network, 4096, restore=True, activation='tanh')network =dropout(network, 0.5)network =fully_connected(network, 3, restore=True, activation='softmax')network =regression(network,optimizer='adam',loss='categorical_crossentropy',learning_rate=0.0001)# 把网络打包为一个模型对象model =tflearn.DNN(network,tensorboard_verbose=3,tensorboard_dir = './logs/',best_checkpoint_path = './best_checkpoint/best_classifier.tfl.ckpt')if _loadModel == True:model.load("./modelSaved/my_model.tflearn")tic = time.time()poss = model.predict(X_test)toc = time.time()print('预测所用时间:%.3fms' % (1000*(toc-tic)))print('有%.2f%%的概率是**,有%.2f%%的概率是**,有%.2f%%的概率是**' % (100*poss[0][0], 100*poss[0][1], 100*poss[0][2]))print('实际为'+str(Y_test_name[0]))if _trainModel == True:model.fit(X, Y,n_epoch=1000,shuffle=True,validation_set=0.1,#对训练数据执行数据分割,10%用于验证show_metric=True,batch_size=64,snapshot_step=200,snapshot_epoch=False,run_id='AlexNet')model.save("modelSaved/my_model.tflearn")print("Networktrained and saved as my_model.tflearn !")

训练结果

在命令行输入命令如:

tensorboard --logdir=classifier\logs\AlexNet

启动tensorboard查看训练结果

计算图:

训练后期,训练准确率在93~95%之间,验证准确率约可以保持在98%左右。

代码运行结束后(或强行终止后),修改代码:

_loadModel = True

_trainModel = False

使其载入训练好的模型,对测试图片(一张)进行预测。预测结果:

用tflearn构建分类器相关推荐

  1. ML之分类预测之LARS:利用回归工具将二分类转为回归问题并采用LARS算法构建分类器

    ML之分类预测之LARS:利用回归工具将二分类转为回归问题并采用LARS算法构建分类器 目录 输出结果 设计思路 代码实现 输出结果 ['V10', 'V48', 'V44', 'V11', 'V35 ...

  2. 超干货|使用Keras和CNN构建分类器(内含代码和讲解)

    摘要: 为了让文章不那么枯燥,我构建了一个精灵图鉴数据集(Pokedex)这都是一些受欢迎的精灵图.我们在已经准备好的图像数据集上,使用Keras库训练一个卷积神经网络(CNN). 为了让文章不那么枯 ...

  3. 贝叶斯模型构建分类器的设计与实现

    多种贝叶斯模型构建及文本分类的实现 作者:白宁超 2015年9月29日11:10:02 摘要:当前数据挖掘技术使用最为广泛的莫过于文本挖掘领域,包括领域本体构建.短文本实体抽取以及代码的语义级构件方法 ...

  4. TensorFlow图像分类:如何构建分类器

    导言 图像分类对于我们来说是一件非常容易的事情,但是对于一台机器来说,在人工智能和深度学习广泛使用之前,这是一项艰巨的任务.自动驾驶汽车能够实时检测物体并采取相应必要的行动,并且由于TensorFlo ...

  5. 构建商品评价的分类器

    接下来,开始构建分类器: 生成的WordCount是一个字典.键值对的形式 这里的键是某一个单词,对应的值是该单词的个数 图像化查看一下原始数据 ,这里我们取出第一个商品的评价 抽取评价数量最多的商品 ...

  6. nlp-with-transformers系列-02-从头构建文本分类器

    大家好,我是致Great,微信TO-Great,欢迎大家公众号ChallengeHub的NLP技术交流群. 文本分类 文本分类是 NLP 中最常见的任务之一, 它可用于广泛的应用或者开发成程序,例如将 ...

  7. Python构建SVM分类器(线性)

    1.SVM建立线性分类器 SVM用来构建分类器和回归器的监督学习模型,SVM通过对数学方程组的求解,可以找出两组数据之间的最佳分割边界. 2.准备工作 我们首先对数据进行可视化,使用的文件来自学习书籍 ...

  8. tflearn教程_TFlearn 快速入门

    在本教程中,您将学习使用TFLearn和TensorFlow来估计泰坦尼克号乘客使用其个人信息(如性别,年龄等)的幸存机会. 为了解决这个经典机器学习任务,我们要构建一个深层神经网络分类器. 前提条件 ...

  9. PyTorch如何构建和实验神经网络

    点击上方"视学算法",马上关注 真爱,请设置"星标"或点个"在看" 作者 | Tirthajyoti Sarkar 来源 | Medium ...

最新文章

  1. 图形化的Redis监控系统redis-stat安装
  2. 计算机网络技术包括哪几种,计算机网络技术包含的两个主要技术是计算机技术和( )。...
  3. html标准模式与混杂模式,关于Doctype、严格模式与混杂模式
  4. php submit 不要刷新,php实现保存submit内容之后禁止刷新
  5. mongodb输错命令后不能删除问题
  6. MyEclipse for Windows 关于 java、jsp、xml、js、html 等文件的注释快捷键及注释格式介绍
  7. 百度SEO站群腾讯短网址w.url.cn生成源码|仿红源码
  8. 浅谈AI芯片的简要发展历史
  9. 计算机毕业设计 志愿者服务管理系统 志愿者系统 志愿者招募系统 志愿者报名管理系统 志愿者信息管理系统 志愿者管理系统 志愿者管理系统源码 志愿者管理系统java 志愿者信息管理系统
  10. 云原生安全之RASP技术(应用运行时自我保护)
  11. 刀马旦计算机音乐,刀马旦 MIDI File Download :: MidiShow
  12. IOS 清理CALayer、CAShapeLayer的sublayers
  13. 14-vue项目搭建.md
  14. Code jock 8.7 源代码编译
  15. 如何高效学习,学习IT知识(转载)
  16. TCP端口检测、网络连接时延测试工具 tcping
  17. linux7.4邮件服务器,U-Mail邮件服务器For CentOS 7.X独立安装包教程
  18. 深度卷积神经网络演化历史及结构改进脉络-40页长文全面解读
  19. 1、Doherty放大器之宽带拓展理论
  20. 系统管理:Unix 文本编辑

热门文章

  1. 汇编语言 CMP指令
  2. 【已解决】彻底修改Tomcat9 控制台 中文乱码问题
  3. 【行业思考】关于商业模式的本质的思考
  4. linux kernel git clone加速
  5. 如何debug Vue源码
  6. oracle+创建diskgroup,Exadata下新建DiskGroup
  7. Mybatis 主键插入回显
  8. 古代的”太阳“是什么意思
  9. uniappAndroid离线打包 小米审核不通过
  10. 个人博客搭建记录 Hexo+Butterfly+Github Page+Coding