python caffe 在师兄的代码上修改成自己风格的代码
首先,感谢师兄的帮助。师兄的代码封装成类,流畅精美,容易调试。我的代码是堆积成的,被师兄嘲笑说写脚本。好吧!我的代码只有我懂,哈哈! 希望以后代码能写得工整点。现在还是让我先懂。这里,我做了一个简单的任务:0,1,2三个数字的分类。准确率:0.9806666666666667
(部分)代码分为:
1 train_net.py
1 #import some module 2 import time 3 import os 4 import numpy as np 5 import sys 6 import cv2 7 sys.path.append("/home/wang/Downloads/caffe-master/python") 8 import caffe 9 #from prepare_data import DataConfig 10 #from data_config import DataConfig 11 12 #configure GPU mode 13 ''' uncommend below line to use gpu ''' 14 caffe.set_mode_gpu() 15 16 # about dataset 17 ##dataset = Dataset('/home/wang/Downloads/object/extract/') 18 ##dataset = dataset.Split('train') 19 ##data_config = DataConfig(dataset) 20 ##data_config.SetBatchSize(256) 21 data_config='/home/wang/Downloads/caffe-master/examples/myFig_recognition/data/train/' 22 23 24 25 #configure solve.prototxt 26 solver = caffe.SGDSolver('models/solver.prototxt') 27 28 # load pretrain model 29 print('load pretrain model') 30 solver.net.copy_from('models/bvlc_reference_caffenet.caffemodel') 31 32 solver.net.layers[0].SetDataConfig(data_config) 33 34 for i in range(1, 10000): 35 # Make one SGD update 36 solver.step(5) 37 if i % 100 == 0: 38 solver.net.save('tmp.caffemodel') 39 ''' TODO: test code '''
2 test_net.py
1 #import setup 2 import time 3 import os 4 import random 5 import sys 6 sys.path.append("/home/wang/Downloads/caffe-master/python") 7 import caffe 8 import cv2 9 import numpy as np 10 import random 11 12 13 from utils import PrepareImage 14 #from dataset import Dataset 15 from test_data import test_data_pre 16 17 test_num_once=10 18 19 20 ''' uncommend below line to use gpu ''' 21 # caffe.set_mode_gpu() 22 23 # dataset 24 #dataset = Dataset('/home/wang/Downloads/object/extract/') 25 #dataset = dataset.Split('test') 26 27 # load net 28 net = caffe.Net('models/deploy.prototxt', caffe.TEST) 29 30 31 # load train model 32 print('load pretrain model') 33 net.copy_from('tmp.caffemodel') 34 35 #test all samples one by one 36 data_pre='/home/wang/Downloads/caffe-master/examples/myFig_recognition/data/test/' 37 #(imgPaths, gt_label) = dataset[int(random.random()*num_obj)] 38 (imgPaths, gt_label)=test_data_pre(data_pre) 39 num_img = len(imgPaths) 40 correct_num=0 41 for idx in range(num_img): 42 img = cv2.imread(imgPaths[idx]) 43 img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) 44 tmp_img = img.copy() # for display 45 img = PrepareImage(img, (227, 227)) 46 net.blobs['data'].reshape(test_num_once, 3, 227, 227) 47 net.blobs['data'].data[...] = img 48 #net.blobs['data'].data[i,:,:,:] = img 49 net.forward() 50 score = net.blobs['cls_prob'].data 51 if score.argmax()==gt_label[idx]: 52 correct_num=correct_num+1 53 if idx%100==0: 54 print("Please wait some minutes...") 55 correct_rate=correct_num*1.0/num_img 56 print('The correct rate is :',correct_rate) 57 58 59
3 test_data.py
1 import os 2 import numpy as np 3 from random import randint 4 import cv2 5 from utils import PrepareImage,CatImage 6 #class data: 7 #path should be /home/ 8 def test_data_pre(path): 9 img_list=[] 10 image_num=len(os.listdir(path+'/0'))+len(os.listdir(path+'/1'))+len(os.listdir(path+'/2')) 11 label = np.zeros(image_num, dtype=np.float32) 12 13 i=0 14 for idf in range(3): 15 idf_str=str(idf) 16 path1=path+idf_str 17 tmp_path=os.listdir(path1) 18 for idi in range(len(tmp_path)): 19 img_path=path1+'/'+tmp_path[idi] 20 img_list.append(img_path) 21 label[i]=idf 22 i=i+1 23 return ( img_list,label)
4 pre_data.py
1 import os 2 import numpy as np 3 from random import randint 4 import cv2 5 from utils import PrepareImage,CatImage 6 #class data: 7 #path should be /home/ 8 def prepare_data(path,batchsize): 9 #tmp_path=os.listdir(path) 10 img_list=[] 11 label = np.zeros(batchsize, dtype=np.float32) 12 for i in range(batchsize): 13 #randomly select one file 14 idf=randint(0,2) 15 idf_str=str(idf) 16 path1=path+idf_str 17 tmp_path=os.listdir(path1) 18 19 #randomly select one image 20 idi=randint(0,len(tmp_path)-1) 21 #img = cv2.imread(imgPaths[idx]) 22 img_path=path1+'/'+tmp_path[idi] 23 img=cv2.imread(img_path) 24 25 img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) 26 flip = randint(0, 1)>0 27 if flip > 0: 28 img = img[:, ::-1, :] # flip left to right 29 30 img=PrepareImage(img, (227,227)) 31 img_list.append(img) 32 label[i]=idf 33 imgData = CatImage(img_list) 34 return (imgData,label)
5 utils.py
1 import os 2 import cv2 3 import numpy as np 4 5 def PrepareImage(im, size): 6 im = cv2.resize(im, (size[0], size[1])) 7 im = im.transpose(2, 0, 1) 8 im = im.astype(np.float32, copy=False) 9 return im 10 11 def CatImage(im_list): 12 max_shape = np.array([im.shape for im in im_list]).max(axis=0) 13 blob = np.zeros((len(im_list), 3, max_shape[1], max_shape[2]), dtype=np.float32) 14 # set to mean value 15 blob[:, 0, :, :] = 102.9801 16 blob[:, 1, :, :] = 115.9465 17 blob[:, 2, :, :] = 122.7717 18 for i, im in enumerate(im_list): 19 blob[i, :, 0:im.shape[1], 0:im.shape[2]] = im 20 return blob
6 layer/data_layer.py
1 import caffe 2 import numpy as np 3 4 #import data_config 5 #import prepare_data 6 from pre_data import prepare_data 7 8 class DataLayer(caffe.Layer): 9 10 def SetDataConfig(self, data_config): 11 self._data_config = data_config 12 13 def GetDataConfig(self): 14 return self._data_config 15 16 def setup(self, bottom, top): 17 # data blob 18 top[0].reshape(1, 3, 227, 227) 19 #top[0].reshape(1, 3, 34, 44) 20 # label type 21 top[1].reshape(1, 1) 22 23 def reshape(self, bootom, top): 24 pass 25 26 def forward(self, bottom, top): 27 #(imgs, label) = self._data_config.next() 28 path=self.GetDataConfig() 29 (imgs,label)=prepare_data(path,128) 30 (N, C, W, H) = imgs.shape 31 # image data 32 top[0].reshape(N, C, W, H) 33 top[0].data[...] = imgs 34 # object type label 35 top[1].reshape(N) 36 top[1].data[...] = label 37 38 def backward(self, top, propagate_down, bottom): 39 pass
7 layer/__init__.py
import data_layer
还有一些caffe中经典的东西没放进来。
代码和数据:
转载于:https://www.cnblogs.com/Wanggcong/p/5169737.html
python caffe 在师兄的代码上修改成自己风格的代码相关推荐
- python阿拉伯数字转中文_python中将阿拉伯数字转换成中文的实现代码
#!/usr/bin/python #-*- encoding: utf-8 -*- import types class NotIntegerError(Exception): pass class ...
- php代码怎么修改成laravel,Laravel框架实现即点即改功能的方法分析
本文实例讲述了Laravel框架实现即点即改功能的方法.分享给大家供大家参考,具体如下: 有的时候我们不需要更改大量数据,只需要更改一个字段的时候,我们就用到了即点即改,以用户模块,修改用户名称为例, ...
- python主题更改_IDLE怎么将主题修改成Darcula样式?
摘要:每个人都有自己心中理想的编辑器主题,我更倾向于Darcula,你们呢? 想必没用过Darcula主题的朋友,会好奇它是何方神圣? 是不是很赏心悦目,代码这冰冷的东西也变得生龙活虎? 我最近在使用 ...
- STM32WB55 在BLE_HeartRateFreeRTOS例程基础上修改成带rtos的p2ps透传服务
STEP1\ 通过对比p2ps和HeartRateFreeRTOS例程增删文件 BLE_HeartRateFreeRTOS例程原目录 替换成 p2p_server_app.c修改内容 1. 增删文件 ...
- 传奇脚本显示服务器开区时间代码,上百种开区脚本代码详细介绍以及脚本示例...
变量名必须大写: 通用变量: ------------------------- $SERVERNAME //服务器名称 $SERVERIP //服务器IP $WEBSITE //网站 在String ...
- ecshop使用php代码,ecshop 修改模板可输出php代码
JSON 之 SuperObject(8): 关于乱码的几种情况 - 向 Henri Gourvest 大师报告 这几天学习 JSON - SuperObject, 非常幸运地得到了其作者 Henri ...
- win11开始菜单怎么修改成win10风格 Windows11开始菜单修改成win10右键风格的设置方法
有很多朋友升级到win11系统之后不是特别喜欢右键菜单,因为经常需要多点击一次显示更多选项,很不舒服.大家就想知道如何修改回原来win10的右键菜单,其实还是有方法的,除了使用软件以外,今天小编就来给 ...
- 【代码质量管理工具】--使用sonarLint提高代码质量
[背景] 项目组,在12月份的时候使用了强大的代码质量管理工具--Sonar来检测规范我们的代码,但是当时使用的时候会有一个缺陷,就是我每使用一次maven命令将代码扫描到sonarqube的网页端, ...
- matlab如何输出总位移,加速度转换成位移的matlab代码及说明
<加速度转换成位移的matlab代码及说明>由会员分享,可在线阅读,更多相关<加速度转换成位移的matlab代码及说明(5页珍藏版)>请在人人文库网上搜索. 1.加速度转换成位 ...
最新文章
- oracle完全卸載,Oracle10g的完全卸載
- React Native JSBundle拆包之原理篇
- 16 导出pcb各网络的布线长度_PCB原理图常见错误分析
- All cached global options setting for WordPress
- 01背包初始化的细节问题与循环下限的改进
- c++二进制转十进制_二进制,八进制,十进制,十六进制转换详解~
- mysql备份:一,Xtrabackup
- 设计模式:观察者模式 ——— 城管来了,摊主快跑
- Exchange Server 2016 独立部署/共存部署 (一)—— 前期准备
- ZDI 公布2020年 Pwn2Own 东京赛规则和奖金
- layui 时间选择器 laydate 设置了默认值时 无法清空
- HDMI调试基本原理
- 谈谈IT行业的各种证书
- abp moveto mysql_abp 使用 hangfire结合mysql
- 微信隐藏功能盘点:修复聊天记录
- 成都1008 hdu4038
- 大数据世界中的新技术
- XP系统启动时滚动条总是时间很长
- leetcode——【猫和老鼠】
- Java Web GIS 地理信息系统开发