import numpy as np
from sklearn import datasets
import matplotlib.pyplot as plt
X,y = datasets.make_moons(n_samples=100,noise = 0.3)
plt.scatter(X[y==0,0],X[y==0,1],color = 'r')
plt.scatter(X[y==1,0],X[y==1,1],color = 'b')
plt.show()

数据集分割

from sklearn.model_selection import train_test_split
X_train,X_test,y_train,y_test =  train_test_split(X, y, test_size=0.2, random_state=666,shuffle=True)
# Parameters:
# *arrays :需要进行划分的X ;
# target :数据集的结果
# test_size :测试集占整个数据集的多少比例
# train_size :test_size +train_size = 1
# random_state : 随机种子
# shuffle : 是否洗牌 在进行划分前# 返回 X_train,X_test,y_train,y_test
plt.scatter(X_train[y_train==0,0],X_train[y_train==0,1])
plt.scatter(X_train[y_train==1,0],X_train[y_train==1,1])
plt.scatter(X_test[y_test==0,0],X_test[y_test==0,1])
plt.scatter(X_test[y_test==1,0],X_test[y_test==1,1])
plt.show()

交叉验证

from sklearn.model_selection import cross_val_score
from sklearn.neighbors import KNeighborsClassifier
knn = KNeighborsClassifier(n_neighbors=5)
cross_val_score(knn,X,y,cv = 3) # cv表示将数据分成几份
array([0.88235294, 0.85294118, 0.875     ])

使用交叉验证获得最优参数

%%time
best_k, best_p, best_score = 0, 0, 0
for k in range(2, 11):for p in range(1, 6):knn_clf = KNeighborsClassifier(weights="distance", n_neighbors=k, p=p)scores = cross_val_score(knn_clf, X_train, y_train)score = np.mean(scores)if score > best_score:best_k, best_p, best_score = k, p, scoreprint("Best K =", best_k)
print("Best P =", best_p)
print("Best Score =", best_score)
Best K = 8
Best P = 3
Best Score = 0.8874643874643874
Wall time: 236 ms

Grid超参数搜索

# gridSearchCV这个是属于网格搜索超参数,这个类需要我y必须放入的一些参数
# 第一个参数是实例化的模型,这里用的knn模型,
# 第二个参数是我们需要网格搜索的超参数,这里param_grid需要有格式;
# 及param_grid必须是一个列表,这个列表是由多个字典组成,字典中的键就是我们需要放入之前的超参数;
# 假如放入knn模型 ,则键可以是weights,n_neihbors或者是p这三个变量其实也就是KNeighborsClassifier的超参数 。
# 我们也可以倒过来理解,每个字典里面都是一组超参数网格搜索,最后再比较多组最优超参数值,从里面挑出一组最优的超参数
from sklearn.model_selection import GridSearchCV
from sklearn.neighbors import KNeighborsClassifier
param_grid=[{'weights':['uniform'],'n_neighbors':[i for i in range(1,11)]},{'weights':['distance'],'n_neighbors':[i for i in range(1,11)],'p':[i for i in range(1,6)]}
]grid_search = GridSearchCV(KNeighborsClassifier(),param_grid)
%%time
grid_search.fit(X_train, y_train)
Wall time: 375 msGridSearchCV(cv=None, error_score='raise',estimator=KNeighborsClassifier(algorithm='auto', leaf_size=30, metric='minkowski',metric_params=None, n_jobs=1, n_neighbors=5, p=2,weights='uniform'),fit_params=None, iid=True, n_jobs=1,param_grid=[{'weights': ['uniform'], 'n_neighbors': [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]}, {'weights': ['distance'], 'n_neighbors': [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], 'p': [1, 2, 3, 4, 5]}],pre_dispatch='2*n_jobs', refit=True, return_train_score='warn',scoring=None, verbose=0)
grid_search.best_estimator_
KNeighborsClassifier(algorithm='auto', leaf_size=30, metric='minkowski',metric_params=None, n_jobs=1, n_neighbors=8, p=3,weights='distance')
grid_search.best_score_ # 获得grid_search产生的
0.8875
%%time
#  verbose表示:
# verbose:日志冗长度,int:冗长度,0:不输出训练过程,1:偶尔输出,>1:对每个子模型都输出。
grid_search = GridSearchCV(KNeighborsClassifier(), param_grid, n_jobs=-1, verbose=4)
grid_search.fit(X_train, y_train)
Fitting 3 folds for each of 60 candidates, totalling 180 fits[Parallel(n_jobs=-1)]: Done   9 tasks      | elapsed:    3.9sWall time: 4.41 s[Parallel(n_jobs=-1)]: Done 180 out of 180 | elapsed:    4.1s finished
grid_search.best_params_  # 获得grid_search产生的最优超参数
{'n_neighbors': 8, 'p': 3, 'weights': 'distance'}
knn_clf = grid_search.best_estimator_  # 获得加入最优超参数后的生成的最优机器学习模型
knn_clf.fit(X_train,y_train)
knn_clf.score(X_test,y_test)
0.9

在网格搜索中增加交叉验证

GridSearchCV(KNeighborsClassifier(), param_grid, n_jobs=-1, verbose=4,cv =4)
GridSearchCV(cv=4, error_score='raise',estimator=KNeighborsClassifier(algorithm='auto', leaf_size=30, metric='minkowski',metric_params=None, n_jobs=1, n_neighbors=5, p=2,weights='uniform'),fit_params=None, iid=True, n_jobs=-1,param_grid=[{'weights': ['uniform'], 'n_neighbors': [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]}, {'weights': ['distance'], 'n_neighbors': [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], 'p': [1, 2, 3, 4, 5]}],pre_dispatch='2*n_jobs', refit=True, return_train_score='warn',scoring=None, verbose=4)

sklearn中的model_selection相关推荐

  1. sklearn中的xgboost_xgboost来了

    一.xgboost前奏 1,介绍一下啥是xgboost XGBoost全称是eXtreme Gradient Boosting,即极限梯度提升算法.它由陈天奇所设计,致力于让提升树突破自身的计算极限, ...

  2. sklearn中分类器的比较

    简 介: 运行对比了 分类器的比较? 中的sklearn中的分类的性能对比.这为我们理解机器学习中的特性提供了理解基础. 关键词: sklearn,python #mermaid-svg-UbOwlP ...

  3. 导入训练好的决策树文件_决策树在sklearn中的实现

    小伙伴们大家好~o( ̄▽ ̄)ブ,今天做一下如何使用sklearn实现决策树,首先声明一下,我的开发环境是Jupyter lab,所用的库和版本大家参考: Python 3.7.1(你的版本至少要3.4 ...

  4. sklearn中的交叉验证(Cross-Validation)

    sklearn中的交叉验证(Cross-Validation) cross validation大概的意思是:对于原始数据我们要将其一部分分为traindata,一部分分为test data.trai ...

  5. Python之 sklearn:sklearn中的train_test_split函数的简介及使用方法之详细攻略

    Python之 sklearn:sklearn中的train_test_split函数的简介及使用方法之详细攻略 目录 sklearn中的train_test_split函数的简介 train_tes ...

  6. Sklearn中的CV与KFold详解

    关于交叉验证,我在之前的文章中已经进行了简单的介绍,而现在我们则通过几个更加详尽的例子.详细的介绍 CV %matplotlib inline import numpy as np from skle ...

  7. sklearn中的回归决策树

    回归 决策树通过使用 DecisionTreeRegressor 类也可以用来解决回归问题. 如在分类设置中,拟合方法将数组X和数组y作为参数,只有在这种情况下,y数组预期才是浮点值: 下面是简单的使 ...

  8. sklearn中的分类决策树

    决策树 决策树简介 决策树是一种使用if-then-else的决策规则的监督学习方法. 其三要素为,枝节点,叶节点与分支条件,同时为了减少过拟合还有剪枝方法 为了便于记忆,可以称其为一方法三要素 决策 ...

  9. sklearn中的Pipline(流水线学习器)

    简介 管道机制实现了对全部步骤的流式化封装和管理(streaming workflows with pipelines). 管道机制(也有人翻译为流水线学习器?这样翻译可能更有利于后面内容的理解)在机 ...

最新文章

  1. Xcode升级到8之后的一些需要我们手动配置的地方
  2. 【C++ 语言】Visual Studio 配置 FFMPEG 开发环境 ( VS2019 CMake 环境安装 | 下载 FFMPEG 开发包 | 配置 FFMPEG )
  3. 深入理解 Java内存模型
  4. 龙剑服务器为什么总是维修,《龙剑》2014年3月13日更新维护公告
  5. tensorflow实现原理
  6. php利用openssl实现RSA非对称加密签名
  7. stream获取filter
  8. dat文本导入mysql_mysql学习笔记(九) 增删改查的优化
  9. 4-2MapReduce的运行流程
  10. DCL 管理用户 mysql
  11. C语言的那些秘密之---函数返回局部变量(转)
  12. 人脸检测(十)--强分类器源码分析
  13. OpenOffice.org 2.0已经发布了。
  14. C++11 static_assert(转载)
  15. 计算机usb接口无法充电,电脑可充电USB接口不能使用怎么办
  16. 牛客编程巅峰赛S1第2场 - 黄金钻石 1.规律 2.bfs
  17. Excel如何冻结窗口
  18. 调频 调幅 与 通信
  19. 51单片机wifi物联网的浇花控制系统设计
  20. nginx教程(一)--nginx是什么?能干什么?

热门文章

  1. 一号店(1号店)静态网页布局HTML5+CSS
  2. NTP时钟服务器(卫星时钟系统)是如何让集成系统协调工作的?
  3. java实现获取各网站的机票信息_java爬取某个机票查询网站上面的信息(刚学!!!)...
  4. 横河变送器EJA110E
  5. C#中WinForm游戏开发——坦克大战
  6. 在MAC下调试运行暗黑世界客户端及部分代码注解(基于Firefly)
  7. 2015的读书计划和读书心得
  8. 插画师培训怎么选,5大插画师培训班排名
  9. ZCL Cluster Library的理解
  10. 未来几年(定制客运)城际拼车业务会严重影响传统客运