目标检测在下一篇->

整理数据:

有蛋就是1,没有就是0。
文件夹1里放拍好的鹅蛋图片,文件夹0里放其他杂乱无章的图片,保证图片里没有鹅蛋图片整理好后,使用def get_data_list(target_path,train_list_path,eval_list_path)生成数据列表:

def get_data_list(target_path,train_list_path,eval_list_path):'''生成数据列表'''#存放所有类别的信息class_detail = []#获取所有类别保存的文件夹名称data_list_path=target_path+"egg01/"class_dirs = os.listdir(data_list_path)  #总的图像数量all_class_images = 0#存放类别标签class_label=0#存放类别数目class_dim = 0#存储要写进eval.txt和train.txt中的内容trainer_list=[]eval_list=[]#读取每个类别,for class_dir in class_dirs:if class_dir != ".DS_Store":class_dim += 1#每个类别的信息class_detail_list = {}eval_sum = 0trainer_sum = 0#统计每个类别有多少张图片class_sum = 0#获取类别路径 path = data_list_path  + class_dir#print(path[:32] + path[36:])# 获取所有图片img_paths = os.listdir(path)for img_path in img_paths:                                  # 遍历文件夹下的每个图片name_path = path + '/' + img_path                       # 每张图片的路径if class_sum % 8 == 0:                                  # 每8张图片取一个做验证数据eval_sum += 1                                       # test_sum为测试数据的数目eval_list.append(name_path + "\t%d" % class_label + "\n")else:trainer_sum += 1 trainer_list.append(name_path + "\t%d" % class_label + "\n")#trainer_sum测试数据的数目class_sum += 1                                          #每类图片的数目all_class_images += 1                                   #所有类图片的数目# 说明的json文件的class_detail数据class_detail_list['class_name'] = class_dir             #类别名称class_detail_list['class_label'] = class_label          #类别标签class_detail_list['class_eval_images'] = eval_sum       #该类数据的测试集数目class_detail_list['class_trainer_images'] = trainer_sum #该类数据的训练集数目class_detail.append(class_detail_list)  #初始化标签列表train_parameters['label_dict'][str(class_label)] = class_dirclass_label += 1 #初始化分类数train_parameters['class_dim'] = class_dim#乱序  random.shuffle(eval_list)with open(eval_list_path, 'a') as f:for eval_image in eval_list:f.write(eval_image) random.shuffle(trainer_list)with open(train_list_path, 'a') as f2:for train_image in trainer_list:f2.write(train_image) # 说明的json文件信息readjson = {}readjson['all_class_name'] = data_list_path                  #文件父目录readjson['all_class_images'] = all_class_imagesreadjson['class_detail'] = class_detailjsons = json.dumps(readjson, sort_keys=True, indent=4, separators=(',', ': '))with open(train_parameters['readme_path'],'w') as f:f.write(jsons)print ('生成数据列表完成!')
'''
参数初始化
'''
src_path=train_parameters['src_path']
target_path=train_parameters['target_path']
train_list_path=train_parameters['train_list_path']
eval_list_path=train_parameters['eval_list_path']

编写dataset类用来获取数据

class dataset(Dataset):def __init__(self, data_path, mode='train'):"""数据读取器:param data_path: 数据集所在路径:param mode: train or eval"""super().__init__()self.data_path = data_pathself.img_paths = []self.labels = []if mode == 'train':with open(os.path.join(self.data_path, "train.txt"), "r", encoding="utf-8") as f:self.info = f.readlines()for img_info in self.info:img_path, label = img_info.strip().split('\t')self.img_paths.append(img_path)self.labels.append(int(label))else:with open(os.path.join(self.data_path, "eval.txt"), "r", encoding="utf-8") as f:self.info = f.readlines()for img_info in self.info:img_path, label = img_info.strip().split('\t')self.img_paths.append(img_path)self.labels.append(int(label))def __getitem__(self, index):"""获取一组数据:param index: 文件索引号:return:"""# 第一步打开图像文件并获取label值img_path = self.img_paths[index]img = Image.open(img_path)if img.mode != 'RGB':img = img.convert('RGB') img = img.resize((224, 224), Image.BILINEAR)img = np.array(img).astype('float32')img = img.transpose((2, 0, 1)) / 255label = self.labels[index]label = np.array([label], dtype="int64")return img, labeldef print_sample(self, index: int = 0):print("文件名", self.img_paths[index], "\t标签值", self.labels[index])def __len__(self):return len(self.img_paths)
#训练数据加载
train_dataset = dataset('/home/aistudio/work',mode='train')
train_loader = paddle.io.DataLoader(train_dataset, batch_size=16, shuffle=True)
#测试数据加载
eval_dataset = dataset('/home/aistudio/work',mode='eval')
eval_loader = paddle.io.DataLoader(eval_dataset, batch_size = 8, shuffle=False)

测试数据文件

train_dataset.print_sample(200)
print(train_dataset.__len__())
eval_dataset.print_sample(0)
print(eval_dataset.__len__())
print(eval_dataset.__getitem__(10)[0].shape)
print(eval_dataset.__getitem__(10)[1].shape)

模型搭建:

我这里使用的是PaddlePaddle预训练模型resnet152,当然你也可以使用其他的

#定义模型
class MyNet(paddle.nn.Layer):def __init__(self):super(MyNet,self).__init__()self.layer=paddle.vision.models.resnet152(pretrained=True)#'ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152', #'VGG', 'vgg11', 'vgg13', 'vgg16', 'vgg19', 'MobileNetV1', 'mobilenet_v1', 'MobileNetV2', 'mobilenet_v2', 'LeNet'self.fc = paddle.nn.Linear(1000, 400)self.relu = paddle.nn.ReLU()self.fc1 = paddle.nn.Linear(400, 20)#网络的前向计算过程def forward(self, x, label=None):x=self.layer(x)x = self.relu(self.fc(x))x=self.fc1(x)if label is not None:acc = paddle.metric.accuracy(input=x, label=label)return x, accelse:return x
#画图
def draw_process(title,color,iters,data,label):plt.title(title, fontsize=24)plt.xlabel("iter", fontsize=20)plt.ylabel(label, fontsize=20)plt.plot(iters, data,color=color,label=label) plt.legend()plt.grid()plt.show()

模型训练:

model =MyNet()
use_gpu = True
#paddle.set_device('gpu:0') if use_gpu else paddle.set_device('cpu')
device = paddle.set_device('gpu')
model.train()
cross_entropy = paddle.nn.CrossEntropyLoss()
optimizer = paddle.optimizer.Adam(learning_rate=train_parameters['learning_strategy']['lr'],parameters=model.parameters()) steps = 0
Iters, total_loss, total_acc = [], [], []for epo in range(train_parameters['num_epochs']):for _, data in enumerate(train_loader()):steps += 1x_data = data[0]y_data = data[1]predicts, acc = model(x_data, y_data)loss = cross_entropy(predicts, y_data)loss.backward()optimizer.step()optimizer.clear_grad()if steps % train_parameters["skip_steps"] == 0:Iters.append(steps)total_loss.append(loss.numpy()[0])total_acc.append(acc.numpy()[0])#打印中间过程print('epo: {}, step: {}, loss is: {}, acc is: {}'\.format(epo, steps, loss.numpy(), acc.numpy()))#保存模型参数if steps % train_parameters["save_steps"] == 0:save_path = train_parameters["checkpoints"]+"/"+"save_dir_" + str(steps) + '.pdparams'print('save model to: ' + save_path)paddle.save(model.state_dict(),save_path)
paddle.save(model.state_dict(),train_parameters["checkpoints"]+"/"+"save_dir_final.pdparams")
draw_process("trainning loss","red",Iters,total_loss,"trainning loss")
draw_process("trainning acc","green",Iters,total_acc,"trainning acc")

模型评估:

model__state_dict = paddle.load('work/checkpoints/save_dir_final.pdparams')
model_eval = MyNet()
model_eval.set_state_dict(model__state_dict)
model_eval.eval()
accs = []for _, data in enumerate(eval_loader()):x_data = data[0]y_data = data[1]predicts = model_eval(x_data)acc = paddle.metric.accuracy(predicts, y_data)accs.append(acc.numpy()[0])
print('模型在验证集上的准确率为:',np.mean(accs))

模型预测:

def load_image(img_path):'''预测图片预处理'''img = Image.open(img_path) if img.mode != 'RGB': img = img.convert('RGB') img = img.resize((224, 224), Image.BILINEAR)img = np.array(img).astype('float32') img = img.transpose((2, 0, 1)) / 255 # HWC to CHW 及归一化return imglabel_dic = train_parameters['label_dict']
model__state_dict = paddle.load('work/checkpoints/save_dir_final.pdparams')
model_predict = MyNet()
model_predict.set_state_dict(model__state_dict)
model_predict.eval()
infer_dst_path = '/home/aistudio/data/'
infer_imgs_path = os.listdir(infer_dst_path+"test")
for infer_img_path in infer_imgs_path:infer_img = load_image(infer_dst_path+"test/"+infer_img_path)infer_img = infer_img[np.newaxis,:, : ,:]  #reshape(-1,3,224,224)infer_img = paddle.to_tensor(infer_img)result = model_predict(infer_img)lab = np.argmax(result.numpy())if '.' in infer_img_path[:2]:infer_img_path='0'+infer_img_path[:1]+infer_img_path[1:]with open("work/result.txt", "a") as f:f.write("{}\n".format(infer_img_path +' '+label_dic[str(lab)]))

当然你也可以把结果排序一下

f=open('work/result.txt')
result= []
iter_f=iter(f)      #用迭代器循环访问文件中的每一行
for line in iter_f:result.append(line)
f.close()
result.sort()
f=open('work/result.txt','w')
f.writelines(result)
f.close()

总结:

数据集太小,样本数量太少,正样本图片过于相像,导致过拟合严重
在测试集上准确率为 100%(震惊!!)
但是验证集上效果并不是很好

按照分类方法判断图片里是否有鹅蛋相关推荐

  1. java判断图片是否被修改过_4种方法判断图片是否被PS处理过,你都会吗? | X的博客...

    "你用什么牌子的化妆品?" "Photoshop" 你是不是经常在网上看到新闻里说某官员的艳照系PS处理过,看到漂亮妹子照片又怀疑是Photoshop处理过?本 ...

  2. Opencv图像处理:判断图片里某个颜色值占的比例

    一.功能 这里的需求是,判断摄像头有没有被物体遮挡.这里只考虑用手遮挡---->判断黑色颜色的范围. 二.使用OpenCV的Mat格式图片遍历图片 下面代码里,传入的图片的尺寸是640*480, ...

  3. javascript判断图片是否加载完成方法整理

    有时候我们在前端开发工作中为了获取图片的信息,需要在图片加载完成后才可以正确的获取到图片的大小尺寸,并且执行相应的回调函数使图片产生某种显示效果.本文主要整理了几种常见的javascipt判断图片加载 ...

  4. js判断数组里是否有重复元素的方法

    转: js判断数组里是否有重复元素的方法 https://blog.csdn.net/longzhoufeng/article/details/78840974 第一种方法:但是下面的这种方法数字字符 ...

  5. PHP简单方法判断文件是否是图片 PHP best way to check if file is an image

    (PHP 4 >= 4.3.0, PHP 5, PHP 7) exif_imagetype - 判断一个图像的类型 图像类型常量 值 常量 1 IMAGETYPE_GIF 2 IMAGETYPE ...

  6. 计算机画图怎样更改文字,如何在图片上改字|超简单的修改图片里文字方法

    这篇文章将要给大家介绍的是,不用联网,不用下载专业的图像处理软件,单纯用画图工具,就能修改表情包.图片上文字的方法,只适合简单的图片处理,复杂的还是交给专业的图像处理工具吧.下面系统吧就给大家带来修改 ...

  7. 在Matlab图片里输入数学公式、符号和希腊字母的方法

    在Matlab图片里输入数学公式.符号和希腊字母的方法 在所有的Matlab Figure里都可以使用大量的Tex代码来输入公式.数学符号等.而且,与Word2007类似,都能够写完立马显示,不对的话 ...

  8. C# 判断图片是CMYK模式还是RGB模式最简单的方法

    C#判断图片是CMYK模式还是RGB模式最简单的方法: Image img = Bitmap.FromFile("图片路径", true); PixelFormat pf = (P ...

  9. 3.js中判断数组中是否存在某个对象/值,判断数组里的对象是否存在某个值 的五种方法 及应用场景|判断数组里有没有某对象,有不添加,没有则添加到数组

    3.js中判断数组中是否存在某个对象/值,判断数组里的对象是否存在某个值 的五种方法 及应用场景 一.当数组中的数据是简单类型时: 应用js中的indexof方法:存在则返回当前项索引,不存在则返回 ...

最新文章

  1. 三维点云分割综述(上)
  2. easyui placeholder 解决方案
  3. [Silverlight]使用MVVM模式打造英汉词典
  4. 直播 | ICML 2021论文解读:具有局部和全局结构的自监督图表征学习
  5. 创业融资十项注意要点
  6. RocketMQ核心概念
  7. 信息学奥赛C++语言:成绩等级
  8. [HDU3037]Saving Beans,插板法+lucas定理
  9. 计算机视觉-混合动态纹理模型(Mixtures of Dynamic Textures)
  10. 记一次失败的电话面试
  11. 安装appach时出现没有安装gcc的错误,用yum安装gcc时yum出现错误(修改yum配置)...
  12. JSF+Spring+Hibernate整合要点
  13. 取色器——TakeColor绿色安全简单
  14. 数据抓取的艺术(一~三):Selenium+Phantomjs数据抓取环境配置
  15. c语言数据域和指针域,C语言的变量域和指针
  16. 算法训练营 图的应用(拓扑排序)
  17. 1和new Number(1)的区别
  18. 计算机等级描述,关于计算机机型描述中,错误的是( )。
  19. 直播 | 如何在顶会夺冠:iWildCam 2020 冠军经验与技巧分享
  20. 原创|我为什么不建议你等公司倒闭后,再找工作!

热门文章

  1. 智能优化算法之遗传算法python实现细节,GA库函数调用方法
  2. 分支与循环语句(下)
  3. 三方接口签名验签简易设计与实现
  4. 服务器操作系统不能显示全屏,服务器窗口显示不全屏
  5. 即时配送行业黑马 闪飞侠2022正式起航
  6. 【NLP】第11章 让你的数据说话:故事、问题和答案
  7. YV12和I420的区别 yuv420和yuv420p的区别
  8. 用户唯一登录,最新登录挤掉以前的登录,实现踢人.
  9. 程序员是如何开灯的 白话闲聊mqtt协议
  10. 并行与并发的区别,一瞬间就能理解并记住