目录

一、安装Julia

二、Flux简介

三、安装Flux和相关依赖库

四、cifar10项目下载

*五、cifar10数据集下载

六、开始训练


一、安装Julia

IDE是Atom,安装和使用教程为:Windows10 Atom安装和运行Julia的使用教程(详细)


二、Flux简介

1.Flux.jl是一个内置于Julia的机器学习框架。它与PyTorch有一些相似之处,就像大多数现代框架一样。

2.Flux是一种优雅的机器学习方法。 它是100%纯Julia堆栈形式,并在Julia的原生GPU和AD支持之上提供轻量级抽象。

3.Flux是一个用于机器学习的库。 它功能强大,具有即插即拔的灵活性,即内置了许多有用的工具,但也可以在需要的地方使用Julia语言的全部功能。

4.Flux遵循以下几个关键原则:

(1) Flux对于正则化或嵌入等功能的显式API相对较少。 相反,写下数学形式将起作用 ,并且速度很快。

(2) 所有的知识和工具,从LSTM到GPU内核,都是简单的Julia代码。 如果有疑问的话,可以查看官方教程。 如果需要不同的函数块或者是功能模块,我们也可以轻松自己动手实现。

(3)Flux适用于Julia库,包括从数据帧和图像到差分方程求解器等等内容,因此我们也可以轻松构建集成Flux模型的复杂数据处理流水线。

5.Flux相关教程链接(FQ):https://fluxml.ai/Flux.jl/stable/

6.Flux模型代码示例链接:https://github.com/FluxML/model-zoo/


三、安装Flux和相关依赖库

1.打开julia控制台,或者打开Atom启动下方REPL的julia,先输入如下指令

using Pkg

2.安装Flux

Pkg.add("Flux")

3.同理,安装依赖项Metalhead

Pkg.add("Metalhead")
Pkg.add("Images")
Pkg.add("Statistics")

一般安装了Metalhead也会自动帮你装上Images和Statistics~


四、cifar10项目下载

1.下载model-zoo文件夹:https://github.com/FluxML/model-zoo/

2.cifar10.jl在model-zoo-master\vision\cifar10中

3.我们在Atom里打开这个项目,如下


*五、cifar10数据集下载

1.github上的model-zoo里cifar10的下载函数里面解压的方式是Linux的,在

C:\Users\你的电脑用户名\.julia\packages\Metalhead\fYeSU\src\datasets\autodetect.jl:

function download(which)if which === ImageNeterror("ImageNet is not automatiacally downloadable. See instructions in datasets/README.md")elseif which == CIFAR10local_path = joinpath(@__DIR__, "..","..",datasets, "cifar-10-binary.tar.gz")#print(local_path)dir_path = joinpath(@__DIR__,"..","..","datasets")if(!isdir(joinpath(dir_path, "cifar-10-batches-bin")))if(!isfile(local_path))Base.download("https://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz", local_path)endrun(`tar -xzvf $local_path -C $dir_path`)endelseerror("Download not supported for $(which)")end
end

这意味着解压函数在windows10上是无效的,但是这并不影响我们在windows上的使用,我们只需要手动下载即可

2.下载地址:https://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz

3.下载完成后请放到这个文件夹(其实是放到这里是为了配合Linux操作系统):

C:\Users\你的电脑用户名\.julia\packages\Metalhead\fYeSU\datasets

解压后的内容如下:

 注意:不放这里,你就等着报错报到死吧!!那就是无法找到cifar10数据集位置!!


六、开始训练

1.核心代码

cifar10.jl

using Flux, Metalhead, Statistics
using Flux: onehotbatch, onecold, crossentropy, throttle
using Metalhead: trainimgs
using Images: channelview
using Statistics: mean
using Base.Iterators: partition# VGG16 and VGG19 modelsvgg16() = Chain(Conv((3, 3), 3 => 64, relu, pad=(1, 1), stride=(1, 1)),BatchNorm(64),Conv((3, 3), 64 => 64, relu, pad=(1, 1), stride=(1, 1)),BatchNorm(64),x -> maxpool(x, (2, 2)),Conv((3, 3), 64 => 128, relu, pad=(1, 1), stride=(1, 1)),BatchNorm(128),Conv((3, 3), 128 => 128, relu, pad=(1, 1), stride=(1, 1)),BatchNorm(128),x -> maxpool(x, (2,2)),Conv((3, 3), 128 => 256, relu, pad=(1, 1), stride=(1, 1)),BatchNorm(256),Conv((3, 3), 256 => 256, relu, pad=(1, 1), stride=(1, 1)),BatchNorm(256),Conv((3, 3), 256 => 256, relu, pad=(1, 1), stride=(1, 1)),BatchNorm(256),x -> maxpool(x, (2, 2)),Conv((3, 3), 256 => 512, relu, pad=(1, 1), stride=(1, 1)),BatchNorm(512),Conv((3, 3), 512 => 512, relu, pad=(1, 1), stride=(1, 1)),BatchNorm(512),Conv((3, 3), 512 => 512, relu, pad=(1, 1), stride=(1, 1)),BatchNorm(512),x -> maxpool(x, (2, 2)),Conv((3, 3), 512 => 512, relu, pad=(1, 1), stride=(1, 1)),BatchNorm(512),Conv((3, 3), 512 => 512, relu, pad=(1, 1), stride=(1, 1)),BatchNorm(512),Conv((3, 3), 512 => 512, relu, pad=(1, 1), stride=(1, 1)),BatchNorm(512),x -> maxpool(x, (2, 2)),x -> reshape(x, :, size(x, 4)),Dense(512, 4096, relu),Dropout(0.5),Dense(4096, 4096, relu),Dropout(0.5),Dense(4096, 10),softmax) |> gpuvgg19() = Chain(Conv((3, 3), 3 => 64, relu, pad=(1, 1), stride=(1, 1)),BatchNorm(64),Conv((3, 3), 64 => 64, relu, pad=(1, 1), stride=(1, 1)),BatchNorm(64),x -> maxpool(x, (2, 2)),Conv((3, 3), 64 => 128, relu, pad=(1, 1), stride=(1, 1)),BatchNorm(128),Conv((3, 3), 128 => 128, relu, pad=(1, 1), stride=(1, 1)),BatchNorm(128),x -> maxpool(x, (2, 2)),Conv((3, 3), 128 => 256, relu, pad=(1, 1), stride=(1, 1)),BatchNorm(256),Conv((3, 3), 256 => 256, relu, pad=(1, 1), stride=(1, 1)),BatchNorm(256),Conv((3, 3), 256 => 256, relu, pad=(1, 1), stride=(1, 1)),BatchNorm(256),Conv((3, 3), 256 => 256, relu, pad=(1, 1), stride=(1, 1)),x -> maxpool(x, (2, 2)),Conv((3, 3), 256 => 512, relu, pad=(1, 1), stride=(1, 1)),BatchNorm(512),Conv((3, 3), 512 => 512, relu, pad=(1, 1), stride=(1, 1)),BatchNorm(512),Conv((3, 3), 512 => 512, relu, pad=(1, 1), stride=(1, 1)),BatchNorm(512),Conv((3, 3), 512 => 512, relu, pad=(1, 1), stride=(1, 1)),x -> maxpool(x, (2, 2)),Conv((3, 3), 512 => 512, relu, pad=(1, 1), stride=(1, 1)),BatchNorm(512),Conv((3, 3), 512 => 512, relu, pad=(1, 1), stride=(1, 1)),BatchNorm(512),Conv((3, 3), 512 => 512, relu, pad=(1, 1), stride=(1, 1)),BatchNorm(512),Conv((3, 3), 512 => 512, relu, pad=(1, 1), stride=(1, 1)),x -> maxpool(x, (2, 2)),x -> reshape(x, :, size(x, 4)),Dense(512, 4096, relu),Dropout(0.5),Dense(4096, 4096, relu),Dropout(0.5),Dense(4096, 10),softmax) |> gpu# Function to convert the RGB image to Float64 Arraysgetarray(X) = Float32.(permutedims(channelview(X), (2, 3, 1)))# Fetching the train and validation data and getting them into proper shapeX = trainimgs(CIFAR10)
imgs = [getarray(X[i].img) for i in 1:50000]
labels = onehotbatch([X[i].ground_truth.class for i in 1:50000],1:10)
train = gpu.([(cat(imgs[i]..., dims = 4), labels[:,i]) for i in partition(1:49000, 100)])
valset = collect(49001:50000)
valX = cat(imgs[valset]..., dims = 4) |> gpu
valY = labels[:, valset] |> gpu# Defining the loss and accuracy functionsm = vgg16()loss(x, y) = crossentropy(m(x), y)accuracy(x, y) = mean(onecold(m(x), 1:10) .== onecold(y, 1:10))# Defining the callback and the optimizerevalcb = throttle(() -> @show(accuracy(valX, valY)), 10)opt = ADAM()# Starting to train modelsFlux.train!(loss, params(m), train, opt, cb = evalcb)# Fetch the test data from Metalhead and get it into proper shape.
# CIFAR-10 does not specify a validation set so valimgs fetch the testdata instead of testimgstest = valimgs(CIFAR10)testimgs = [getarray(test[i].img) for i in 1:10000]
testY = onehotbatch([test[i].ground_truth.class for i in 1:10000], 1:10) |> gpu
testX = cat(testimgs..., dims = 4) |> gpu# Print the final accuracy@show(accuracy(testX, testY))

2.菜单栏里 Packages->Julia->Run File,可以在REPL里看到训练的效果,也就是最后一句代码展示准确度

3.至于如何放到GPU上训练,我们还需要下载CuArrays:

Using Pkg
Pkg.add("CuArrays")

以及安装CUDA和cuDNN支持,具体细节看官方文档:https://fluxml.ai/Flux.jl/stable/gpu/#Installation-

Julia 基于Flux深度学习框架的cifar10数据集分类相关推荐

  1. 基于TensorFlow深度学习框架,运用python搭建LeNet-5卷积神经网络模型和mnist手写数字识别数据集,设计一个手写数字识别软件。

    本软件是基于TensorFlow深度学习框架,运用LeNet-5卷积神经网络模型和mnist手写数字识别数据集所设计的手写数字识别软件. 具体实现如下: 1.读入数据:运用TensorFlow深度学习 ...

  2. 【深度学习】基于PyTorch深度学习框架的序列图像数据装载器

    作者 | Harsh Maheshwari 编译 | VK 来源 | Towards Data Science 如今,深度学习和机器学习算法正在统治世界.PyTorch是最常用的深度学习框架之一,用于 ...

  3. 【开源项目推荐-ColugoMum】这群本科生基于国产深度学习框架PaddlePadddle开源了零售行业解决方案

    零售行业是我国非常重要的行业之一,随着手机支付和购物用户数量的不断提高,以及数字化技术的不断发展,零售行业的企业尤其是线下体验店对数字化转型的意愿不断加强,未来我国智慧零售行业有望持续快速发展. 那么 ...

  4. julia有 pytorch包吗_有了Julia语言,深度学习框架从此不需要计算图

    选自julialang 作者:Mike Innes 等 机器之心编译 参与:刘晓坤.思源 本文基于 NeurIPS MLSys 的一篇论文<Fashionable Modelling with ...

  5. 【深度学习】训练CIFAR-10数据集实现分类加测试

    网上有很多博主写的训练CIFAR-10的代码,本次只是单纯记录一下自己调试的一个程序,对于初学深度学习的小白可以参考,如有不对,请多多见谅!!! 一.CIFAR-10数据集由10个类的60000个32 ...

  6. 基于迁移深度学习的遥感图像场景分类

    前述 根据语义特征对遥感图像场景进行分类是一项具有挑战性的任务.因为遥感图像场景的类内变化较大,而类间变化有时却较小.不同的物体会以不同的尺度和方向出现在同一类场景中,而同样的物体也可能出现在不同的场 ...

  7. 开源基于PyTorch深度学习框架实现图卷积

    开源代码参考:学习与优化 Graph Convolutional Networks paper -> paper link -> github Distilling Knowledge F ...

  8. 基于kera-yolo3深度学习框架的目标检测

    1.在github官网上面下载kera-yolo3框架,网址为:https://github.com/qqwweee/keras-yolo3(注意:github网站不稳定时可到中国网站gitee官网去 ...

  9. JAVA深度学习框架DJL之鞋子分类

    机器学习生命周期 遵循机器学习生命周期来生成鞋类分类模型. ML生命周期不同于传统的软件开发生命周期,它包含六个具体步骤: 获取数据 清理并准备数据 产生模型 评估模型 部署模型 从模型获得预测(或推 ...

最新文章

  1. 从Elasticsearch来看分布式系统架构设计,真是666~
  2. 人脑启发AI设计:让神经网络统一翻译语音和文本
  3. Linux操作系统上lsof命令详解
  4. [密码学基础][每个信息安全博士生应该知道的52件事][Bristol Cryptography][第4篇] P类复杂问题
  5. Spring AOP 功能使用详解
  6. 双Y轴echarts
  7. matlab相位相关图像配准,数字图像处理,相位相关图像配准
  8. PHP最全笔记(三)(值得收藏,不时翻看一下)
  9. php format tool,usb 開機碟製作工具HP USB Disk Storage format Tool 2.23
  10. 【AllenNLP入门教程】: 2、基于Allennlp2.4版本的一些使用技巧
  11. deepstream-test3
  12. 剑三 服务器状态查询,数据互通全面启动_剑侠情缘网络版叁_金山游戏官方网站_金山逍遥Xoyo.com...
  13. 开运魔法,晓腾叔叔的日常迷信。
  14. 2021-2026年中国畜牧业发展环境分析及投资前景预测报告
  15. mysql是如何保证持久性的?
  16. ENVI系列--遥感影像UTM投影计算公式
  17. JavaScript实现数字金额小写转大写
  18. 什么是http接口?
  19. Python实现视频转 gif 动图
  20. mysql中in的参数有限制_数据库 in 可以包含的参数个数

热门文章

  1. vim: 根据编程语言自动选择不同的colorscheme
  2. Android APP开发入门 使用Android Studio环境pdf
  3. 在jOOQ中获取数据的多种不同方式
  4. android 9 以上,使用HTTPclient
  5. 静态代码块是什么?有什么用?
  6. 爆炸式的工作机会和多项目同步
  7. 开通www 国际域名个人网站操作介绍
  8. 给南开大学礼鹤同学的回信----关于开源的思考
  9. 如何在html图片里输入文字居中显示,CSS设置文字图片垂直居中的方法总结
  10. symbol xxx multiply defined