Convolutional Radio Modulation Recognition Networks 论文代码复现

1. 数据集下载

数据集生成源代码(需要GNU Radio来实现,如需安装学习参考主页相关教程)
链接: https://github.com/radioML/dataset
现成的数据集下载(来自DeepSig公司主页)
链接: https://www.deepsig.ai/datasets
一般来说,RADIOML 2016.10A就可以满足需求,其他数据集将消耗更大的硬件需求。

2. 原论文

链接: https://arxiv.org/abs/1602.04105

3. 软硬件约束

1.tensorflow2,keras2
2.gpu:RTX3080Ti,cpu:Intel Xeon E5-2690 v4

4. 代码

(1)导包,本来我是tf2,后来为了适应tf1版本,加了几行代码,现在tf1应该可以用。

from tensorflow.compat.v1 import ConfigProto
from tensorflow.compat.v1 import InteractiveSessionconfig = ConfigProto()
config.gpu_options.allow_growth = True
session = InteractiveSession(config=config)import h5py
import numpy as np
#import theano as th
import os,random
import pickle,random, sys
import tensorflow as tf
tf.compat.v1.enable_eager_execution()
from tensorflow import keras
import matplotlib.pyplot as plt
import gc
from tensorflow.keras.layers import Reshape,Dense,Dropout,Activation,Flatten,Convolution2D, MaxPooling2D, ZeroPadding2D,LSTM%matplotlib inline
#os.environ["KERAS_BACKEND"] = "tensorflow"
os.environ["THEANO_FLAGS"]  = "device=gpu%d"%(0)

(2)导入数据集,X是数据,lbl是标签(包含调制类型和信噪比)

Xd = pickle.load(open("RML2016.10a_dict.pkl",'rb'),encoding='latin1')
#data=pk.load(f,encoding='latin1')
#print(data)
snrs,mods = map(lambda j: sorted(list(set(map(lambda x: x[j], Xd.keys())))), [1,0])
X = []
lbl = []
for mod in mods:for snr in snrs:X.append(Xd[(mod,snr)])for i in range(Xd[(mod,snr)].shape[0]):  lbl.append((mod,snr))
X = np.vstack(X)

(3)划分训练集和数据集,to_onehot打标签

np.random.seed(2016)
n_examples = X.shape[0]
n_train = int(n_examples*0.8)
train_idx = np.random.choice(range(0,n_examples), size=n_train, replace=False)
test_idx = list(set(range(0,n_examples))-set(train_idx))
X_train = X[train_idx]
X_test =  X[test_idx]
def to_onehot(yy):yy1 = np.zeros([len(yy), max(yy)+1])yy1[np.arange(len(yy)),yy] = 1return yy1
Y_train = to_onehot(list(map(lambda x: mods.index(lbl[x][0]), train_idx)))
Y_test = to_onehot(list(map(lambda x: mods.index(lbl[x][0]), test_idx)))in_shp = list(X_train.shape[1:])
print (X_train.shape, in_shp)classes = mods
print('数据集总数:',n_examples)
print('调制方式' , len(mods),'种:' ,mods)
print('信噪比:',snrs)

(4)神经网络结构

dr = 0.5 # dropout rate (%)
lstm_output_size=512
model = keras.Sequential()
model.add(Reshape(in_shp+[1], input_shape=in_shp))
model.add(ZeroPadding2D(padding=(0,2)))
model.add(Convolution2D(256, (1,3),  padding='valid', activation="relu", name="conv1"))
model.add(Dropout(dr))
model.add(ZeroPadding2D(padding=(0, 2)))
model.add(Convolution2D(80, (2, 3), padding="valid", activation="relu", name="conv2"))
model.add(Dropout(dr))
model.add(Flatten())
model.add(Dense(1024, activation='relu', name="dense1"))
model.add(Dense(256, activation='relu', name="dense1"))
model.add(Dropout(dr))
model.add(Dense( len(classes), name="dense2" ))
model.add(Activation('softmax'))
model.add(Reshape([len(classes)]))
model.compile(loss='categorical_crossentropy', optimizer='adam',metrics=["accuracy"])model.summary()

(5)训练,引入了早停法

# Set up some params
epochs = 100     # number of epochs to train on
batch_size = 1024  # training batch size
#   - call the main training loop in keras for our network+dataset
filepath = 'conv.h5'
history = model.fit(X_train,Y_train,batch_size=batch_size,epochs=epochs,verbose=2,validation_data=(X_test, Y_test),callbacks = [keras.callbacks.ModelCheckpoint(filepath, monitor='val_loss', verbose=0, save_best_only=True, mode='auto'),keras.callbacks.EarlyStopping(monitor='val_loss', patience=5, verbose=0, mode='auto')])
# we re-load the best weights once training is finished
model.load_weights(filepath)

(6)记录并绘制损失曲线

score = model.evaluate(X_test, Y_test, batch_size=batch_size,verbose=0)
print (score)
acc = history.history['accuracy']
val_acc = history.history['val_accuracy']
loss = history.history['loss']
val_loss = history.history['val_loss']
plt.subplot(1, 2, 1)
plt.plot(acc, label='Training Accuracy')
plt.plot(val_acc, label='Validation Accuracy')
plt.title('Training and Validation Accuracy')
plt.legend()plt.subplot(1, 2, 2)
plt.plot(loss, label='Training Loss')
plt.plot(val_loss, label='Validation Loss')
plt.title('Training and Validation Loss')
plt.legend()
plt.show()

(7)绘制混淆矩阵

def plot_confusion_matrix(cm, title='Confusion matrix', cmap=plt.cm.Blues, labels=[]):plt.imshow(cm, interpolation='nearest', cmap=cmap)plt.title(title)plt.colorbar()tick_marks = np.arange(len(labels))plt.xticks(tick_marks, labels, rotation=45)plt.yticks(tick_marks, labels)plt.tight_layout()plt.ylabel('True label')plt.xlabel('Predicted label'
# Plot confusion matrix
test_Y_hat = model.predict(X_test, batch_size=batch_size)
conf = np.zeros([len(classes),len(classes)])
confnorm = np.zeros([len(classes),len(classes)])
for i in range(0,X_test.shape[0]):j = list(Y_test[i,:]).index(1)k = int(np.argmax(test_Y_hat[i,:]))conf[j,k] = conf[j,k] + 1
for i in range(0,len(classes)):confnorm[i,:] = conf[i,:] / np.sum(conf[i,:])
plot_confusion_matrix(confnorm, labels=classes)# Plot confusion matrix
acc = {}
for snr in snrs:# extract classes @ SNRtest_SNRs = list(map(lambda x: lbl[x][1], test_idx))#map在tensorflow2中前面要加listtest_X_i = X_test[np.where(np.array(test_SNRs)==snr)]test_Y_i = Y_test[np.where(np.array(test_SNRs)==snr)]    # estimate classestest_Y_i_hat = model.predict(test_X_i)conf = np.zeros([len(classes),len(classes)])confnorm = np.zeros([len(classes),len(classes)])for i in range(0,test_X_i.shape[0]):j = list(test_Y_i[i,:]).index(1)k = int(np.argmax(test_Y_i_hat[i,:]))conf[j,k] = conf[j,k] + 1for i in range(0,len(classes)):confnorm[i,:] = conf[i,:] / np.sum(conf[i,:])plt.figure()plot_confusion_matrix(confnorm, labels=classes, title="ConvNet Confusion Matrix (SNR=%d)"%(snr))cor = np.sum(np.diag(conf))ncor = np.sum(conf) - corprint ("Overall Accuracy: ", cor / (cor+ncor))acc[snr] = 1.0*cor/(cor+ncor)

(8)绘制信噪比曲线

# Save results to a pickle file for plotting later
print (acc)
fd = open('results_cnn.dat','wb')
pickle.dump( ("CNN2", 0.5, acc) , fd )
# Plot accuracy curve
plt.plot(snrs, list(map(lambda x: acc[x], snrs)))
plt.xlabel('Signal to Noise Ratio')
plt.ylabel('Classification Accuracy')
plt.title("CNN2 Classification Accuracy on RadioML 2016.10 Alpha")

更多调制识别论文的复现论文将陆陆续续发布!

O’Shea基于深度学习调制识别代码相关推荐

  1. 每周AI应用方案精选:虹膜识别;基于深度学习人脸识别方案等

    2019-12-12 17:52:41 每周三期,详解人工智能产业解决方案,让AI离你更近一步. 解决方案均选自机器之心Pro行业数据库. 方案1:虹膜识别解决方案 解决方案简介: 虹膜识别技术是基于 ...

  2. B站UP搭建世界首个纯红石神经网络、基于深度学习动作识别的色情检测、陈天奇《机器学编译MLC》课程进展、AI前沿论文 | ShowMeAI资讯日报

    ShowMeAI日报系列全新升级!覆盖AI人工智能 工具&框架 | 项目&代码 | 博文&分享 | 数据&资源 | 研究&论文 等方向.点击查看 历史文章列表, ...

  3. 基于深度学习的花卉检测与识别系统(YOLOv5清新界面版,Python代码)

    摘要:基于深度学习的花卉检测与识别系统用于常见花卉识别计数,智能检测花卉种类并记录和保存结果,对各种花卉检测结果可视化,更加方便准确辨认花卉.本文详细介绍花卉检测与识别系统,在介绍算法原理的同时,给出 ...

  4. 基于深度学习的犬种识别软件(YOLOv5清新界面版,Python代码)

    摘要:基于深度学习的犬种识别软件用于识别常见多个犬品种,基于YOLOv5算法检测犬种,并通过界面显示记录和管理,智能辅助人们辨别犬种.本文详细介绍博主自主开发的犬种检测系统,在介绍算法原理的同时,给出 ...

  5. 基于深度学习识别模型的缺陷检测

    根据读者反映,咱们的这个PCB素材设置的不对,应该是没有漆,铜线等等,应该是黑白的.额,具体我也知道,但没去过工厂,实在很难获得这些素材... 所以就当是一次瑕疵识别的实践,具体的数据集你可以自己定义 ...

  6. 毕业设计 基于深度学习的动物识别 - 卷积神经网络 机器视觉 图像识别

    文章目录 0 前言 1 背景 2 算法原理 2.1 动物识别方法概况 2.2 常用的网络模型 2.2.1 B-CNN 2.2.2 SSD 3 SSD动物目标检测流程 4 实现效果 5 部分相关代码 5 ...

  7. 基于深度学习的农作物叶片病害检测系统(UI界面+YOLOv5+训练数据集)

    摘要:农作物叶片病害检测系统用于智能检测常见农作物叶片病害情况,自动化标注.记录和保存病害位置和类型,辅助作物病害防治以增加产值.本文详细介绍基于YOLOv5深度学习模型的农作物叶片病害检测系统,在介 ...

  8. 基于深度学习的智能PCB板缺陷检测系统(Python+清新界面+数据集)

    摘要:智能PCB板缺陷检测系统用于智能检测工业印刷电路板(PCB)常见缺陷,自动化标注.记录和保存缺陷位置和类型,以辅助电路板的质检.本文详细介绍智能PCB板缺陷检测系统,在介绍算法原理的同时,给出P ...

  9. 深度学习在恶意代码检测方面的应用简单调研

    随着互联网的繁荣,现阶段的恶意代码也呈现出快速发展的趋势,主要表现为变种数量多.传播速度快.影响范围广.在这样的形势下,传统的恶意代码检测方法已经无法满足人们对恶意代码检测的要求.比如基于签名特征码的 ...

最新文章

  1. 手机做服务器性能咋样,服务器性能不足 怎样才能逼出最强状态
  2. Android 常用数据操作封装类案例
  3. 115网盘 最好的网盘 雨林木风出品 强烈推荐
  4. ASP.NET读取XML文件
  5. 输电线缺陷检测 计算机工程与设计,工业CT检测与计算机深度学习
  6. ECharts快速上手 入门教学
  7. Pygame——创建游戏地图
  8. 【网络安全】OSSIM平台网络日志关联分析实战
  9. 苹果手机设置邮箱服务器端口设置,苹果手机邮箱怎样设置
  10. 微信鉴权服务器地址,微信开发之微信授权登录
  11. Word2013使用 插入题注的方式为word自带编辑器编辑的公式进行编号
  12. 帝国CMS7.5基于es(Elasticsearch)7.x的全文搜索插件
  13. linux系统路由器地址查询,如何在任何平台上查找路由器的IP地址
  14. SCOI2016酱油记
  15. 机器人教育的中心地段
  16. 10月,你知道有哪些程序员热点新书上榜了吗?
  17. 通过金矿模型介绍动态规划
  18. django基础到高手知识笔记总结,50页笔记,共10大模块(第一期).md
  19. ddos打高防服务器_高防服务器防御DDOS***、CC***方法?
  20. matlab求球心坐标,已知四顶点坐标求四面体外接球球心坐标

热门文章

  1. 传智播客成都中心“基础加强班”优惠活动最后一期,立马围观。
  2. 2017年8月23日 星期三 --出埃及记 Exodus 29:2
  3. The requested URL was not found on the server. If you entered the URL manually please check your spe
  4. typescript 八叉树的简单实现
  5. QuantLib 金融计算——基本组件之天数计算规则详解
  6. 解决Incorrect result size: expected 1, actual 0!
  7. PF_RING 6.0.2在Redhat 6.3 x86_64上编译和安装
  8. IE浏览器图片不显示,报DOM7009: 无法解码 URL 处的图像问题的解决方法
  9. macOS High Sierra 10.13正式版USB安装盘制作
  10. 模拟投票小程序C语言代码,微信小程序投票系统创建投票发布demo完整源码下载 一个很简单 - 下载 - 搜珍网...