首先,感谢师兄的帮助。师兄的代码封装成类,流畅精美,容易调试。我的代码是堆积成的,被师兄嘲笑说写脚本。好吧!我的代码只有我懂,哈哈! 希望以后代码能写得工整点。现在还是让我先懂。这里,我做了一个简单的任务:0,1,2三个数字的分类。准确率:0.9806666666666667

(部分)代码分为:

1 train_net.py

 1 #import some module
 2 import time
 3 import os
 4 import numpy as np
 5 import sys
 6 import cv2
 7 sys.path.append("/home/wang/Downloads/caffe-master/python")
 8 import caffe
 9 #from prepare_data import DataConfig
10 #from data_config import DataConfig
11
12 #configure GPU mode
13 ''' uncommend below line to use gpu '''
14 caffe.set_mode_gpu()
15
16 # about dataset
17 ##dataset = Dataset('/home/wang/Downloads/object/extract/')
18 ##dataset = dataset.Split('train')
19 ##data_config = DataConfig(dataset)
20 ##data_config.SetBatchSize(256)
21 data_config='/home/wang/Downloads/caffe-master/examples/myFig_recognition/data/train/'
22
23
24
25 #configure solve.prototxt
26 solver = caffe.SGDSolver('models/solver.prototxt')
27
28 # load pretrain model
29 print('load pretrain model')
30 solver.net.copy_from('models/bvlc_reference_caffenet.caffemodel')
31
32 solver.net.layers[0].SetDataConfig(data_config)
33
34 for i in range(1, 10000):
35     # Make one SGD update
36     solver.step(5)
37     if i % 100 == 0:
38         solver.net.save('tmp.caffemodel')
39         ''' TODO:  test code '''  

2 test_net.py

 1 #import setup
 2 import time
 3 import os
 4 import random
 5 import sys
 6 sys.path.append("/home/wang/Downloads/caffe-master/python")
 7 import caffe
 8 import cv2
 9 import numpy as np
10 import random
11
12
13 from utils import PrepareImage
14 #from dataset import Dataset
15 from test_data import test_data_pre
16
17 test_num_once=10
18
19
20 ''' uncommend below line to use gpu '''
21 # caffe.set_mode_gpu()
22
23 # dataset
24 #dataset = Dataset('/home/wang/Downloads/object/extract/')
25 #dataset = dataset.Split('test')
26
27 # load net
28 net = caffe.Net('models/deploy.prototxt', caffe.TEST)
29
30
31 # load train model
32 print('load pretrain model')
33 net.copy_from('tmp.caffemodel')
34
35 #test all samples one by one
36 data_pre='/home/wang/Downloads/caffe-master/examples/myFig_recognition/data/test/'
37 #(imgPaths, gt_label) = dataset[int(random.random()*num_obj)]
38 (imgPaths, gt_label)=test_data_pre(data_pre)
39 num_img = len(imgPaths)
40 correct_num=0
41 for idx in range(num_img):
42     img = cv2.imread(imgPaths[idx])
43     img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
44     tmp_img = img.copy() # for display
45     img = PrepareImage(img, (227, 227))
46     net.blobs['data'].reshape(test_num_once, 3, 227, 227)
47     net.blobs['data'].data[...] = img
48     #net.blobs['data'].data[i,:,:,:] = img
49     net.forward()
50     score = net.blobs['cls_prob'].data
51     if score.argmax()==gt_label[idx]:
52         correct_num=correct_num+1
53     if idx%100==0:
54         print("Please wait some minutes...")
55 correct_rate=correct_num*1.0/num_img
56 print('The correct rate is :',correct_rate)
57
58
59     

3 test_data.py

 1 import os
 2 import numpy as np
 3 from random import randint
 4 import cv2
 5 from utils import PrepareImage,CatImage
 6 #class data:
 7 #path should be /home/
 8 def test_data_pre(path):
 9     img_list=[]
10     image_num=len(os.listdir(path+'/0'))+len(os.listdir(path+'/1'))+len(os.listdir(path+'/2'))
11     label = np.zeros(image_num, dtype=np.float32)
12
13     i=0
14     for idf in range(3):
15         idf_str=str(idf)
16         path1=path+idf_str
17         tmp_path=os.listdir(path1)
18         for idi in range(len(tmp_path)):
19             img_path=path1+'/'+tmp_path[idi]
20             img_list.append(img_path)
21             label[i]=idf
22             i=i+1
23     return ( img_list,label)

4 pre_data.py

 1 import os
 2 import numpy as np
 3 from random import randint
 4 import cv2
 5 from utils import PrepareImage,CatImage
 6 #class data:
 7 #path should be /home/
 8 def prepare_data(path,batchsize):
 9     #tmp_path=os.listdir(path)
10     img_list=[]
11     label = np.zeros(batchsize, dtype=np.float32)
12     for i in range(batchsize):
13         #randomly select one file
14         idf=randint(0,2)
15         idf_str=str(idf)
16         path1=path+idf_str
17         tmp_path=os.listdir(path1)
18
19         #randomly select one image
20         idi=randint(0,len(tmp_path)-1)
21         #img = cv2.imread(imgPaths[idx])
22         img_path=path1+'/'+tmp_path[idi]
23         img=cv2.imread(img_path)
24
25         img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
26         flip = randint(0, 1)>0
27         if flip > 0:
28             img = img[:, ::-1, :] # flip left to right
29
30         img=PrepareImage(img, (227,227))
31         img_list.append(img)
32         label[i]=idf
33     imgData = CatImage(img_list)
34     return (imgData,label)

5 utils.py

 1 import os
 2 import cv2
 3 import numpy as np
 4
 5 def PrepareImage(im, size):
 6     im = cv2.resize(im, (size[0], size[1]))
 7     im = im.transpose(2, 0, 1)
 8     im = im.astype(np.float32, copy=False)
 9     return im
10
11 def CatImage(im_list):
12     max_shape = np.array([im.shape for im in im_list]).max(axis=0)
13     blob = np.zeros((len(im_list), 3, max_shape[1], max_shape[2]), dtype=np.float32)
14     # set to mean value
15     blob[:, 0, :, :] = 102.9801
16     blob[:, 1, :, :] = 115.9465
17     blob[:, 2, :, :] = 122.7717
18     for i, im in enumerate(im_list):
19         blob[i, :, 0:im.shape[1], 0:im.shape[2]] = im
20     return blob

6 layer/data_layer.py

 1 import caffe
 2 import numpy as np
 3
 4 #import data_config
 5 #import prepare_data
 6 from pre_data import prepare_data
 7
 8 class DataLayer(caffe.Layer):
 9
10     def SetDataConfig(self, data_config):
11         self._data_config = data_config
12
13     def GetDataConfig(self):
14         return self._data_config
15
16     def setup(self, bottom, top):
17         # data blob
18         top[0].reshape(1, 3, 227, 227)
19         #top[0].reshape(1, 3, 34, 44)
20         # label type
21         top[1].reshape(1, 1)
22
23     def reshape(self, bootom, top):
24         pass
25
26     def forward(self, bottom, top):
27         #(imgs, label) = self._data_config.next()
28         path=self.GetDataConfig()
29         (imgs,label)=prepare_data(path,128)
30         (N, C, W, H) = imgs.shape
31         # image data
32         top[0].reshape(N, C, W, H)
33         top[0].data[...] = imgs
34         # object type label
35         top[1].reshape(N)
36         top[1].data[...] = label
37
38     def backward(self, top, propagate_down, bottom):
39         pass

7 layer/__init__.py

import data_layer

还有一些caffe中经典的东西没放进来。

代码和数据:

转载于:https://www.cnblogs.com/Wanggcong/p/5169737.html

python caffe 在师兄的代码上修改成自己风格的代码相关推荐

  1. python阿拉伯数字转中文_python中将阿拉伯数字转换成中文的实现代码

    #!/usr/bin/python #-*- encoding: utf-8 -*- import types class NotIntegerError(Exception): pass class ...

  2. php代码怎么修改成laravel,Laravel框架实现即点即改功能的方法分析

    本文实例讲述了Laravel框架实现即点即改功能的方法.分享给大家供大家参考,具体如下: 有的时候我们不需要更改大量数据,只需要更改一个字段的时候,我们就用到了即点即改,以用户模块,修改用户名称为例, ...

  3. python主题更改_IDLE怎么将主题修改成Darcula样式?

    摘要:每个人都有自己心中理想的编辑器主题,我更倾向于Darcula,你们呢? 想必没用过Darcula主题的朋友,会好奇它是何方神圣? 是不是很赏心悦目,代码这冰冷的东西也变得生龙活虎? 我最近在使用 ...

  4. STM32WB55 在BLE_HeartRateFreeRTOS例程基础上修改成带rtos的p2ps透传服务

    STEP1\ 通过对比p2ps和HeartRateFreeRTOS例程增删文件 BLE_HeartRateFreeRTOS例程原目录 替换成 p2p_server_app.c修改内容 1. 增删文件 ...

  5. 传奇脚本显示服务器开区时间代码,上百种开区脚本代码详细介绍以及脚本示例...

    变量名必须大写: 通用变量: ------------------------- $SERVERNAME //服务器名称 $SERVERIP //服务器IP $WEBSITE //网站 在String ...

  6. ecshop使用php代码,ecshop 修改模板可输出php代码

    JSON 之 SuperObject(8): 关于乱码的几种情况 - 向 Henri Gourvest 大师报告 这几天学习 JSON - SuperObject, 非常幸运地得到了其作者 Henri ...

  7. win11开始菜单怎么修改成win10风格 Windows11开始菜单修改成win10右键风格的设置方法

    有很多朋友升级到win11系统之后不是特别喜欢右键菜单,因为经常需要多点击一次显示更多选项,很不舒服.大家就想知道如何修改回原来win10的右键菜单,其实还是有方法的,除了使用软件以外,今天小编就来给 ...

  8. 【代码质量管理工具】--使用sonarLint提高代码质量

    [背景] 项目组,在12月份的时候使用了强大的代码质量管理工具--Sonar来检测规范我们的代码,但是当时使用的时候会有一个缺陷,就是我每使用一次maven命令将代码扫描到sonarqube的网页端, ...

  9. matlab如何输出总位移,加速度转换成位移的matlab代码及说明

    <加速度转换成位移的matlab代码及说明>由会员分享,可在线阅读,更多相关<加速度转换成位移的matlab代码及说明(5页珍藏版)>请在人人文库网上搜索. 1.加速度转换成位 ...

最新文章

  1. oracle完全卸載,Oracle10g的完全卸載
  2. React Native JSBundle拆包之原理篇
  3. 16 导出pcb各网络的布线长度_PCB原理图常见错误分析
  4. All cached global options setting for WordPress
  5. 01背包初始化的细节问题与循环下限的改进
  6. c++二进制转十进制_二进制,八进制,十进制,十六进制转换详解~
  7. mysql备份:一,Xtrabackup
  8. 设计模式:观察者模式 ——— 城管来了,摊主快跑
  9. Exchange Server 2016 独立部署/共存部署 (一)—— 前期准备
  10. ZDI 公布2020年 Pwn2Own 东京赛规则和奖金
  11. layui 时间选择器 laydate 设置了默认值时 无法清空
  12. HDMI调试基本原理
  13. 谈谈IT行业的各种证书
  14. abp moveto mysql_abp 使用 hangfire结合mysql
  15. 微信隐藏功能盘点:修复聊天记录
  16. 成都1008 hdu4038
  17. 大数据世界中的新技术
  18. XP系统启动时滚动条总是时间很长
  19. leetcode——【猫和老鼠】
  20. Java Web GIS 地理信息系统开发

热门文章

  1. Spark2.1.0分布式集群安装
  2. CommonsMultipartFile 转为 File 类型
  3. vue-router使用入门
  4. 设计模式10——flyweight模式
  5. 线程状态以及sleep yield wait join方法
  6. JavaScript权威指南(第六版) 初读笔记
  7. Delphi XE5 常见问题解答
  8. (转)oracle extent
  9. 编译原理练习题(第二章)
  10. html邮件模板编辑器_免费电子邮件群发工具推荐「aweber」