O’Shea基于深度学习调制识别代码
Convolutional Radio Modulation Recognition Networks 论文代码复现
1. 数据集下载
数据集生成源代码(需要GNU Radio来实现,如需安装学习参考主页相关教程)
链接: https://github.com/radioML/dataset
现成的数据集下载(来自DeepSig公司主页)
链接: https://www.deepsig.ai/datasets
一般来说,RADIOML 2016.10A就可以满足需求,其他数据集将消耗更大的硬件需求。
2. 原论文
链接: https://arxiv.org/abs/1602.04105
3. 软硬件约束
1.tensorflow2,keras2
2.gpu:RTX3080Ti,cpu:Intel Xeon E5-2690 v4
4. 代码
(1)导包,本来我是tf2,后来为了适应tf1版本,加了几行代码,现在tf1应该可以用。
from tensorflow.compat.v1 import ConfigProto
from tensorflow.compat.v1 import InteractiveSessionconfig = ConfigProto()
config.gpu_options.allow_growth = True
session = InteractiveSession(config=config)import h5py
import numpy as np
#import theano as th
import os,random
import pickle,random, sys
import tensorflow as tf
tf.compat.v1.enable_eager_execution()
from tensorflow import keras
import matplotlib.pyplot as plt
import gc
from tensorflow.keras.layers import Reshape,Dense,Dropout,Activation,Flatten,Convolution2D, MaxPooling2D, ZeroPadding2D,LSTM%matplotlib inline
#os.environ["KERAS_BACKEND"] = "tensorflow"
os.environ["THEANO_FLAGS"] = "device=gpu%d"%(0)
(2)导入数据集,X是数据,lbl是标签(包含调制类型和信噪比)
Xd = pickle.load(open("RML2016.10a_dict.pkl",'rb'),encoding='latin1')
#data=pk.load(f,encoding='latin1')
#print(data)
snrs,mods = map(lambda j: sorted(list(set(map(lambda x: x[j], Xd.keys())))), [1,0])
X = []
lbl = []
for mod in mods:for snr in snrs:X.append(Xd[(mod,snr)])for i in range(Xd[(mod,snr)].shape[0]): lbl.append((mod,snr))
X = np.vstack(X)
(3)划分训练集和数据集,to_onehot打标签
np.random.seed(2016)
n_examples = X.shape[0]
n_train = int(n_examples*0.8)
train_idx = np.random.choice(range(0,n_examples), size=n_train, replace=False)
test_idx = list(set(range(0,n_examples))-set(train_idx))
X_train = X[train_idx]
X_test = X[test_idx]
def to_onehot(yy):yy1 = np.zeros([len(yy), max(yy)+1])yy1[np.arange(len(yy)),yy] = 1return yy1
Y_train = to_onehot(list(map(lambda x: mods.index(lbl[x][0]), train_idx)))
Y_test = to_onehot(list(map(lambda x: mods.index(lbl[x][0]), test_idx)))in_shp = list(X_train.shape[1:])
print (X_train.shape, in_shp)classes = mods
print('数据集总数:',n_examples)
print('调制方式' , len(mods),'种:' ,mods)
print('信噪比:',snrs)
(4)神经网络结构
dr = 0.5 # dropout rate (%)
lstm_output_size=512
model = keras.Sequential()
model.add(Reshape(in_shp+[1], input_shape=in_shp))
model.add(ZeroPadding2D(padding=(0,2)))
model.add(Convolution2D(256, (1,3), padding='valid', activation="relu", name="conv1"))
model.add(Dropout(dr))
model.add(ZeroPadding2D(padding=(0, 2)))
model.add(Convolution2D(80, (2, 3), padding="valid", activation="relu", name="conv2"))
model.add(Dropout(dr))
model.add(Flatten())
model.add(Dense(1024, activation='relu', name="dense1"))
model.add(Dense(256, activation='relu', name="dense1"))
model.add(Dropout(dr))
model.add(Dense( len(classes), name="dense2" ))
model.add(Activation('softmax'))
model.add(Reshape([len(classes)]))
model.compile(loss='categorical_crossentropy', optimizer='adam',metrics=["accuracy"])model.summary()
(5)训练,引入了早停法
# Set up some params
epochs = 100 # number of epochs to train on
batch_size = 1024 # training batch size
# - call the main training loop in keras for our network+dataset
filepath = 'conv.h5'
history = model.fit(X_train,Y_train,batch_size=batch_size,epochs=epochs,verbose=2,validation_data=(X_test, Y_test),callbacks = [keras.callbacks.ModelCheckpoint(filepath, monitor='val_loss', verbose=0, save_best_only=True, mode='auto'),keras.callbacks.EarlyStopping(monitor='val_loss', patience=5, verbose=0, mode='auto')])
# we re-load the best weights once training is finished
model.load_weights(filepath)
(6)记录并绘制损失曲线
score = model.evaluate(X_test, Y_test, batch_size=batch_size,verbose=0)
print (score)
acc = history.history['accuracy']
val_acc = history.history['val_accuracy']
loss = history.history['loss']
val_loss = history.history['val_loss']
plt.subplot(1, 2, 1)
plt.plot(acc, label='Training Accuracy')
plt.plot(val_acc, label='Validation Accuracy')
plt.title('Training and Validation Accuracy')
plt.legend()plt.subplot(1, 2, 2)
plt.plot(loss, label='Training Loss')
plt.plot(val_loss, label='Validation Loss')
plt.title('Training and Validation Loss')
plt.legend()
plt.show()
(7)绘制混淆矩阵
def plot_confusion_matrix(cm, title='Confusion matrix', cmap=plt.cm.Blues, labels=[]):plt.imshow(cm, interpolation='nearest', cmap=cmap)plt.title(title)plt.colorbar()tick_marks = np.arange(len(labels))plt.xticks(tick_marks, labels, rotation=45)plt.yticks(tick_marks, labels)plt.tight_layout()plt.ylabel('True label')plt.xlabel('Predicted label'
# Plot confusion matrix
test_Y_hat = model.predict(X_test, batch_size=batch_size)
conf = np.zeros([len(classes),len(classes)])
confnorm = np.zeros([len(classes),len(classes)])
for i in range(0,X_test.shape[0]):j = list(Y_test[i,:]).index(1)k = int(np.argmax(test_Y_hat[i,:]))conf[j,k] = conf[j,k] + 1
for i in range(0,len(classes)):confnorm[i,:] = conf[i,:] / np.sum(conf[i,:])
plot_confusion_matrix(confnorm, labels=classes)# Plot confusion matrix
acc = {}
for snr in snrs:# extract classes @ SNRtest_SNRs = list(map(lambda x: lbl[x][1], test_idx))#map在tensorflow2中前面要加listtest_X_i = X_test[np.where(np.array(test_SNRs)==snr)]test_Y_i = Y_test[np.where(np.array(test_SNRs)==snr)] # estimate classestest_Y_i_hat = model.predict(test_X_i)conf = np.zeros([len(classes),len(classes)])confnorm = np.zeros([len(classes),len(classes)])for i in range(0,test_X_i.shape[0]):j = list(test_Y_i[i,:]).index(1)k = int(np.argmax(test_Y_i_hat[i,:]))conf[j,k] = conf[j,k] + 1for i in range(0,len(classes)):confnorm[i,:] = conf[i,:] / np.sum(conf[i,:])plt.figure()plot_confusion_matrix(confnorm, labels=classes, title="ConvNet Confusion Matrix (SNR=%d)"%(snr))cor = np.sum(np.diag(conf))ncor = np.sum(conf) - corprint ("Overall Accuracy: ", cor / (cor+ncor))acc[snr] = 1.0*cor/(cor+ncor)
(8)绘制信噪比曲线
# Save results to a pickle file for plotting later
print (acc)
fd = open('results_cnn.dat','wb')
pickle.dump( ("CNN2", 0.5, acc) , fd )
# Plot accuracy curve
plt.plot(snrs, list(map(lambda x: acc[x], snrs)))
plt.xlabel('Signal to Noise Ratio')
plt.ylabel('Classification Accuracy')
plt.title("CNN2 Classification Accuracy on RadioML 2016.10 Alpha")
更多调制识别论文的复现论文将陆陆续续发布!
O’Shea基于深度学习调制识别代码相关推荐
- 每周AI应用方案精选:虹膜识别;基于深度学习人脸识别方案等
2019-12-12 17:52:41 每周三期,详解人工智能产业解决方案,让AI离你更近一步. 解决方案均选自机器之心Pro行业数据库. 方案1:虹膜识别解决方案 解决方案简介: 虹膜识别技术是基于 ...
- B站UP搭建世界首个纯红石神经网络、基于深度学习动作识别的色情检测、陈天奇《机器学编译MLC》课程进展、AI前沿论文 | ShowMeAI资讯日报
ShowMeAI日报系列全新升级!覆盖AI人工智能 工具&框架 | 项目&代码 | 博文&分享 | 数据&资源 | 研究&论文 等方向.点击查看 历史文章列表, ...
- 基于深度学习的花卉检测与识别系统(YOLOv5清新界面版,Python代码)
摘要:基于深度学习的花卉检测与识别系统用于常见花卉识别计数,智能检测花卉种类并记录和保存结果,对各种花卉检测结果可视化,更加方便准确辨认花卉.本文详细介绍花卉检测与识别系统,在介绍算法原理的同时,给出 ...
- 基于深度学习的犬种识别软件(YOLOv5清新界面版,Python代码)
摘要:基于深度学习的犬种识别软件用于识别常见多个犬品种,基于YOLOv5算法检测犬种,并通过界面显示记录和管理,智能辅助人们辨别犬种.本文详细介绍博主自主开发的犬种检测系统,在介绍算法原理的同时,给出 ...
- 基于深度学习识别模型的缺陷检测
根据读者反映,咱们的这个PCB素材设置的不对,应该是没有漆,铜线等等,应该是黑白的.额,具体我也知道,但没去过工厂,实在很难获得这些素材... 所以就当是一次瑕疵识别的实践,具体的数据集你可以自己定义 ...
- 毕业设计 基于深度学习的动物识别 - 卷积神经网络 机器视觉 图像识别
文章目录 0 前言 1 背景 2 算法原理 2.1 动物识别方法概况 2.2 常用的网络模型 2.2.1 B-CNN 2.2.2 SSD 3 SSD动物目标检测流程 4 实现效果 5 部分相关代码 5 ...
- 基于深度学习的农作物叶片病害检测系统(UI界面+YOLOv5+训练数据集)
摘要:农作物叶片病害检测系统用于智能检测常见农作物叶片病害情况,自动化标注.记录和保存病害位置和类型,辅助作物病害防治以增加产值.本文详细介绍基于YOLOv5深度学习模型的农作物叶片病害检测系统,在介 ...
- 基于深度学习的智能PCB板缺陷检测系统(Python+清新界面+数据集)
摘要:智能PCB板缺陷检测系统用于智能检测工业印刷电路板(PCB)常见缺陷,自动化标注.记录和保存缺陷位置和类型,以辅助电路板的质检.本文详细介绍智能PCB板缺陷检测系统,在介绍算法原理的同时,给出P ...
- 深度学习在恶意代码检测方面的应用简单调研
随着互联网的繁荣,现阶段的恶意代码也呈现出快速发展的趋势,主要表现为变种数量多.传播速度快.影响范围广.在这样的形势下,传统的恶意代码检测方法已经无法满足人们对恶意代码检测的要求.比如基于签名特征码的 ...
最新文章
- 手机做服务器性能咋样,服务器性能不足 怎样才能逼出最强状态
- Android 常用数据操作封装类案例
- 115网盘 最好的网盘 雨林木风出品 强烈推荐
- ASP.NET读取XML文件
- 输电线缺陷检测 计算机工程与设计,工业CT检测与计算机深度学习
- ECharts快速上手 入门教学
- Pygame——创建游戏地图
- 【网络安全】OSSIM平台网络日志关联分析实战
- 苹果手机设置邮箱服务器端口设置,苹果手机邮箱怎样设置
- 微信鉴权服务器地址,微信开发之微信授权登录
- Word2013使用 插入题注的方式为word自带编辑器编辑的公式进行编号
- 帝国CMS7.5基于es(Elasticsearch)7.x的全文搜索插件
- linux系统路由器地址查询,如何在任何平台上查找路由器的IP地址
- SCOI2016酱油记
- 机器人教育的中心地段
- 10月,你知道有哪些程序员热点新书上榜了吗?
- 通过金矿模型介绍动态规划
- django基础到高手知识笔记总结,50页笔记,共10大模块(第一期).md
- ddos打高防服务器_高防服务器防御DDOS***、CC***方法?
- matlab求球心坐标,已知四顶点坐标求四面体外接球球心坐标
热门文章
- 传智播客成都中心“基础加强班”优惠活动最后一期,立马围观。
- 2017年8月23日 星期三 --出埃及记 Exodus 29:2
- The requested URL was not found on the server. If you entered the URL manually please check your spe
- typescript 八叉树的简单实现
- QuantLib 金融计算——基本组件之天数计算规则详解
- 解决Incorrect result size: expected 1, actual 0!
- PF_RING 6.0.2在Redhat 6.3 x86_64上编译和安装
- IE浏览器图片不显示,报DOM7009: 无法解码 URL 处的图像问题的解决方法
- macOS High Sierra 10.13正式版USB安装盘制作
- 模拟投票小程序C语言代码,微信小程序投票系统创建投票发布demo完整源码下载 一个很简单 - 下载 - 搜珍网...