返回主目录

返回集成学习目录

上一章:机器篇——集成学习(七) 细说 XGBoost 算法

下一章:机器篇——集成学习(九) 细说 hotel_pred 项目(酒店预测)

本小节,细说 ball49_pred 项目,下一小节细说 hotel_pred 项目

三. 项目解说

8. ball49_pred 项目

本章节项目 github 地址:ball49_pred

终于熬到了项目讲解了。本章节只做知识学术研究探讨,研究探讨外的责任概不负责。做其他用途的读者,请自己对自己的行为负责。

在现实中,我们拿到的数据,往往没有那么理想,直接丢进模型,就可以进行训练的。就拿本节内容将要用到的数据来说吧,本节是,假设,有一种彩票,它一共有 49 个号码,每次开奖时,会开 6 个普通号码,1 个特别号码,一共开 7 个号码。基本上是每周开 3 次奖。如 2019 年 前 5 期 的开奖号码如下:

2019001  25.23.45.30.08.32+24
2019002 23.13.19.45.16.12+02
2019003 40.04.23.05.12.30+10
2019004 07.31.28.06.36.43+23
2019005 35.48.18.37.49.10+27

在上面这个开奖号码中,空格键前面的,是开奖期号数;空格键后面,是开奖号码,前面 6 个号码是普通码,+ (加号) 后面的是特别码(如 2019001    25.23.45.30.08.32+24:2019001 表示 2019年第一期;普通号码为 25.23.45.30.08.32;特别号码为 24)。面对这样的数据,我们要怎么去预测下一期的开奖号码呢?

想要预测准确,不管是对于机器学习,还是深度学习,都需要海量的数据。当时,从前面的信息,我们知道,这个将只是每周开 3 次奖,一年下来,也就 140+ 多数据,10年也就1400+条数据,也远远达不到海量的要求。而且,时间段越久,数据的规律也许就越复杂,越复杂的数据,越不容易进行预测找规律。为什么说时间跨越越久,数据规律越复杂呢?打个比方,我们只是拿到开奖的号码,但是,这个号码是用什么算法(姑且把它当做算法)得到的呢?我们对此一无所知,而时间跨度大,则容易出现,今年可能用 A 算法得到开奖号码,明年可能用 B 算法得到开奖号码,后年则可能用 C 算法得到的开奖号码。你试想想,如果让你用固定的数据,去拟合一个算法,效果也许会不错;但是,如果用固定的数据,去拟合 N 个算法,那就头疼了。尽管这只是可能。。。

但是,我们又想要进行预测,那,只能曲线救国了。这里的曲线救国,是因为数据量不够的情况下,我们只对下期进行一个单双,或大小的预测。因为只是进行单双或大小的预测,我们就要把利益最大化,而预测 特别码的单双或大小,可以获得这种有限数据情况下,利益的最大化。

同时,我们先只用最近一年的数据进行训练和预测,根据效果,来进行增加或删减数据。比如,先用今年的数据进行训练和预测,然后,再用今年和去年的数据进行训练和预测,对比两者的模型效果,选择较好的模型。这,就类似于网格搜索方式。以此,来控制数据的选用。

下面,我就以 2019 年的 140+条数据作为例子,做一个特别码单双的预测。

2019年的数据如下:train_ori_data.txt

2019001  25.23.45.30.08.32+24
2019002 23.13.19.45.16.12+02
2019003 40.04.23.05.12.30+10
2019004 07.31.28.06.36.43+23
2019005 35.48.18.37.49.10+27
2019006 44.33.16.17.10.48+07
2019007 16.18.28.35.41.36+37
2019008 34.47.24.13.31.44+08
2019009 30.33.03.19.25.48+44
2019010 46.41.18.33.43.11+36
2019011 46.14.45.41.27.29+21
2019012 14.45.03.10.24.32+29
2019013 22.26.31.33.41.18+27
2019014 05.35.13.06.34.19+39
2019015 15.19.04.29.02.11+01
2019016 01.34.39.49.19.29+43
2019017 40.09.26.17.32.31+27
2019018 46.05.43.08.20.16+11
2019019 03.08.17.19.07.23+12
2019020 33.16.08.13.05.02+41
2019021 39.36.16.49.29.42+34
2019022 14.49.15.33.13.21+01
2019023 19.12.13.09.21.46+15
2019024 43.30.24.25.22.32+11
2019025 24.44.26.47.35.21+08
2019026 16.49.34.32.28.10+45
2019027 15.08.17.23.39.44+40
2019028 13.32.46.34.41.35+30
2019029 33.38.47.27.06.21+46
2019030 28.40.10.04.43.25+39
2019031 09.12.31.23.05.43+29
2019032 40.38.18.46.23.04+15
2019033 12.32.37.45.07.22+47
2019034 15.39.10.11.49.46+21
2019035 21.03.02.31.10.35+11
2019036 44.41.05.03.23.17+40
2019037 37.39.49.15.17.18+31
2019038 15.21.06.24.28.10+49
2019039 42.44.29.34.38.47+36
2019040 41.37.21.45.34.08+49
2019041 14.32.43.17.23.03+36
2019042 06.38.32.18.41.45+33
2019043 40.22.15.35.49.03+39
2019044 30.11.20.32.43.33+34
2019045 37.32.30.47.39.28+23
2019046 21.11.28.34.18.13+19
2019047 21.08.14.02.35.32+04
2019048 45.49.43.29.01.32+05
2019049 19.06.23.01.45.47+49
2019050 24.19.13.35.06.20+43
2019051 44.12.39.06.35.42+22
2019052 29.05.04.32.24.03+02
2019053 12.30.41.10.14.37+05
2019054 38.37.04.07.23.01+34
2019055 09.41.02.01.16.30+35
2019056 04.07.21.35.41.24+18
2019057 43.13.11.17.33.14+27
2019058 14.30.40.48.23.12+01
2019059 37.49.42.29.45.22+47
2019060 16.15.31.36.46.12+03
2019061 19.03.13.43.35.23+26
2019062 16.24.21.28.37.40+48
2019063 21.24.47.06.32.13+38
2019064 24.03.27.49.28.05+44
2019065 25.19.08.37.20.21+13
2019066 03.31.33.21.38.06+45
2019067 42.30.17.44.08.07+24
2019068 48.23.29.35.26.18+36
2019069 15.25.11.09.33.21+46
2019070 29.33.13.30.47.43+16
2019071 29.46.25.32.40.23+13
2019072 20.46.47.12.42.33+27
2019073 29.13.46.09.31.22+45
2019074 45.14.15.42.13.48+02
2019075 44.15.01.24.02.12+14
2019076 42.11.28.08.10.07+29
2019077 05.31.42.20.39.03+25
2019078 49.03.26.06.32.27+29
2019079 32.14.12.46.38.35+16
2019080 28.06.16.33.37.45+13
2019081 17.02.04.11.08.31+19
2019082 47.26.27.44.33.30+37
2019083 25.38.42.19.39.14+46
2019084 16.10.07.18.35.11+43
2019085 14.45.04.24.39.19+08
2019086 03.22.44.19.16.41+29
2019087 21.20.13.43.49.18+30
2019088 24.28.43.27.23.37+16
2019089 06.21.37.22.15.05+36
2019090 19.30.08.44.25.09+37
2019091 48.49.44.22.15.42+06
2019092 29.14.30.16.10.43+11
2019093 13.02.38.23.35.46+17
2019094 32.44.04.09.26.20+11
2019095 12.32.23.35.16.15+02
2019096 06.42.01.05.04.43+08
2019097 01.31.47.21.49.10+13
2019098 06.29.18.33.43.46+16
2019099 38.26.24.44.08.37+18
2019100 14.32.44.06.24.01+05
2019101 02.27.24.39.08.17+10
2019102 04.44.29.03.18.41+39
2019103 08.43.07.25.45.12+13
2019104 43.49.12.28.13.11+47
2019105 33.25.24.23.05.31+07
2019106 28.07.19.04.12.01+43
2019107 12.26.20.30.22.36+02
2019108 44.08.17.42.39.34+07
2019109 42.28.39.44.03.40+37
2019110 06.44.17.49.46.45+41
2019111 03.27.20.46.11.02+04
2019112 48.07.02.18.10.03+33
2019113 12.19.13.15.45.01+44
2019114 23.36.43.35.47.13+16
2019115 10.37.27.42.29.16+44
2019116 19.18.45.21.14.47+49
2019117 22.06.44.19.42.01+25
2019118 31.11.01.41.47.06+32
2019119 10.16.23.27.07.34+24
2019120 49.24.17.30.40.01+48
2019121 20.19.13.02.10.24+17
2019122 13.25.15.46.30.40+26
2019123 33.34.20.30.07.44+46
2019124 18.06.23.46.31.17+44
2019125 03.13.31.49.02.18+14
2019126 23.47.28.44.39.10+13
2019127 01.24.13.12.42.04+48
2019128 03.05.10.35.38.04+07
2019129 46.28.20.35.01.32+15
2019130 42.25.45.15.05.09+38
2019131 22.23.36.41.29.44+02
2019132 10.40.01.49.24.18+07
2019133 15.23.45.01.47.29+06
2019134 02.23.31.16.06.08+01
2019135 47.25.30.26.14.32+04
2019136 11.35.28.06.12.05+17
2019137 32.39.49.21.35.06+15
2019138 02.08.28.04.04.29+05
2019139 41.34.30.23.20.03+47
2019140 35.31.22.24.39.28+12

预测数据  predict_data.txt  内容如下:

2019140  35.31.22.24.39.28+12
2019141 11.24.26.22.41.42+02
2019142 23.11.30.14.32.05+20
2019143 27.25.36.31.22.01+47
2019144 28.10.18.20.05.33+17
2020001 37.29.15.34.30.17+40
2020002 27.24.37.25.30.01+08
2020003 17.42.37.35.49.05+39
2020004 02.30.24.46.20.32+01

项目层级结构:

info 和 models 内的文件,是用本章节的代码,直接生成的。

多跑几次,也许就能跑出一个较好的模型来用。或者,你也可以增加数据,进行训练,得到更好的模型。dataset 里面,还有 2017 和 2018 年的数据,感兴趣的朋友,也可以自己去整理来训练,得到新的模型,用来预测。

环境依赖:

pip install numpy==1.16
conda install scikit-learn
pip install xgboost
pip install joblib
pip install easydict
pip install matplotlib

README.md 文件如下:

# ball49_pred
XGBoost 彩票预测 2020-1-18
- 项目自带少量数据,如果想要更多数据,则自己想办法去搞了。
- 运行prepare.py将数据集划为训练集,验证集和测试集## 训练模型
- 开始训练模型之前,先进行调参
- 运行 ball49_xgboost_train.py 中的下列方法
- demo.best_estimators_depth() # 调节模型的迭代次数和深度
- demo.best_lr_gamma()  # 调节模型的学习率和gamma值
- demo.best_subsmaple_bytree()  # 调节子样本集和叶子数(即调节行和列)
- demo.best_nthread_weight()    # 调节最小子权重和线程数
- demo.best_seek()  # 调节随机种子- 将调参寻找到的较优的参数,配置到 demo.best_param_xgboost() 方法中,开始训练
- 训练模型的时候,会绘制 ROC 曲线,方面目测效果的好坏。
- 如果不想展示,可以在 config.py 中将 __C.TRAIN.ROC_FLAG 设置为 False
- 多训练几次,寻找一个较好的模型。## 预测
- 加载权重,将训练好的权重 .m 文件放入models文件夹
- 运行 ball49_xgboost_test.py,对数据集进行预测输出,预测结果打印在控制栏上。
- 预测效果如下:
- 第 2019141 期 预测结果为 (双), 真实结果为 (02-双)
- 第 2019142 期 预测结果为 (双), 真实结果为 (20-双)
- 第 2019143 期 预测结果为 (单), 真实结果为 (47-单)
- 第 2019144 期 预测结果为 (单), 真实结果为 (17-单)
- 第 2020001 期 预测结果为 (双), 真实结果为 (40-双)
- 第 2020002 期 预测结果为 (单), 真实结果为 (08-双)
- 第 2020003 期 预测结果为 (单), 真实结果为 (39-单)
- 第 2020004 期 预测结果为 (单), 真实结果为 (01-单)
- 第 2020005 期 预测结果为 (双), 真实结果为 (坐等开奖收钱!...... )

1. 第一步,数据预处理,得到可以用来进行训练和预测的数据。

配置文件:config.py

#!/usr/bin/env python
# _*_ coding:utf-8 _*_
# ============================================
# @Time     : 2020/01/15 22:41
# @Author   : WanDaoYi
# @FileName : config.py
# ============================================from easydict import EasyDict as edict
import os__C = edict()cfg = __C# common options 公共配置文件
__C.COMMON = edict()
# windows 获取文件绝对路径, 方便 windows 在黑窗口 运行项目
__C.COMMON.BASE_PATH = os.path.abspath(os.path.dirname(__file__))
# # 获取当前窗口的路径, 当用 Linux 的时候切用这个,不然会报错。(windows也可以用这个)
# __C.COMMON.BASE_PATH = os.getcwd()# 训练集,验证集,测试集占的百分比
__C.COMMON.TRAIN_PERCENT = 0.9
__C.COMMON.VAL_PERCENT = 0.1# 模型训练配置文件
__C.TRAIN = edict()# 是否绘制 ROC 曲线,绘制为 True
__C.TRAIN.ROC_FLAG = True# 数据路径
__C.TRAIN.DATA_PATH = os.path.join(__C.COMMON.BASE_PATH, "dataset/train_ori_data.txt")
# 将数据转为目标数据的路径
__C.TRAIN.TRAIN_DATA_INFO_PATH = os.path.join(__C.COMMON.BASE_PATH, "info/train_data.txt")
__C.TRAIN.VAL_DATA_INFO_PATH = os.path.join(__C.COMMON.BASE_PATH, "info/val_data.txt")# 模型保存路径
__C.TRAIN.MODEL_SAVE_PATH = os.path.join(__C.COMMON.BASE_PATH, "models/model_")# 模型预测配置文件
__C.TEST = edict()__C.TEST.DATA_PATH = os.path.join(__C.COMMON.BASE_PATH, "dataset/predict_data.txt")
__C.TEST.TEST_DATA_INFO_PATH = os.path.join(__C.COMMON.BASE_PATH, "info/test_data.txt")# 使用 acc 高的模型,当模型 acc 大于 0.5 时,用 True,否则用 False
__C.TEST.ACC_FLAG = False# 模型路径
__C.TEST.MODEL_PATH = os.path.join(__C.COMMON.BASE_PATH, "models/model_acc=0.200000.m")
# __C.TEST.MODEL_PATH = os.path.join(__C.COMMON.BASE_PATH, "models/model_acc=0.785714.m")

数据预处理文件:prepare.py

#!/usr/bin/env python
# _*_ coding:utf-8 _*_
# ============================================
# @Time     : 2020/01/15 22:43
# @Author   : WanDaoYi
# @FileName : prepare.py
# ============================================import numpy as np
import random
from config import cfgclass DataPrepare(object):def __init__(self):self.data_path = cfg.TRAIN.DATA_PATHself.test_path = cfg.TEST.DATA_PATHself.train_data_info_path = cfg.TRAIN.TRAIN_DATA_INFO_PATHself.val_data_info_path = cfg.TRAIN.VAL_DATA_INFO_PATHself.test_data_info_path = cfg.TEST.TEST_DATA_INFO_PATHself.train_percent = cfg.COMMON.TRAIN_PERCENTpass# 读取数据def read_data(self, data_path):""":param data_path: 读取文件的路径:return:"""time_info = []concate_info = []# 去掉数据中的 空格符 制表符 .  + 等符合。with open(data_path, "r") as file:txt_info = file.readlines()for data in txt_info:data_info = data.strip()data_num = data_info.split("\t")time_num = data_num[0]time_info.append(time_num)num_info = data_num[-1]num_list = num_info.split(".")last_num = num_list[-1]split_num = last_num.split("+")concate_num = num_list[: -1] + split_numconcate_info.append(concate_num)pass# 判断彩票周期是否倒序,如果倒序则要处理为顺序。后面设置 label 需要用到time_first = int(time_info[0])time_last = int(time_info[-1])if time_first > time_last:time_info.reverse()concate_info.reverse()pass# 将分开的数据整合data_info = np.c_[time_info, concate_info]return data_info# 数据划分方法def data_split(self, data_list, split_percent):""":param data_list: 要划分的list:param split_percent: 划分的百分比:return:"""# 数据的长度data_len = len(data_list)# 划分的长度n_split = int(split_percent * data_len)i = 0n_split_index_list = []while True:random_num = random.randint(0, data_len)if random_num in n_split_index_list:continuen_split_index_list.append(random_num)i += 1if i == n_split:breakpassn_split_data_list = []leave_data_list = []for index_number, value_info in enumerate(data_list):if index_number in n_split_index_list:n_split_data_list.append(value_info)else:leave_data_list.append(value_info)passpassreturn n_split_data_list, leave_data_listdef add_label(self, data_list):# 设置 label 值。# 将下一期开的特别码 的 单双,设置为 这期的label 值。双为 0,单为 1label_list = []for number in data_list:obj_number = int(number[-1])if obj_number % 2 == 0:label_list.append("0")else:label_list.append("1")passobj_data = np.c_[data_list[: -1], label_list[1:]]return obj_datapass# 保存数据def data_save(self, data_path, data_list):""":param data_path: 保存路径:param data_list: 需要保存的 python 原生 list:return:"""# 将目标数据进行保存,每个数值,以 . 号 隔开data_file = open(data_path, "w")for data in data_list:data_file.write(".".join([info for info in data]))data_file.write("\n")passdata_file.close()pass# 数据拼接保存def generate_data(self):train_info = self.read_data(self.data_path)test_info = self.read_data(self.test_path)# 下面的方法,可以将 list 中内容的类型之间转为 int 类型# data_info = np.array(data_info, dtype=int)train_val_data = self.add_label(train_info)# 转为python 原生的list,下面 write 方法需要原生的 list,numpy 的list无法保存train_val_data = train_val_data.tolist()test_data_list = test_info.tolist()train_data_list, val_data_list = self.data_split(train_val_data, self.train_percent)self.data_save(self.train_data_info_path, train_data_list)self.data_save(self.val_data_info_path, val_data_list)self.data_save(self.test_data_info_path, test_data_list)print("data prepare already!")passif __name__ == "__main__":demo = DataPrepare()demo.generate_data()pass

prepare.py 代码得到的数据有 train_data.txt,val_data.txt,test_data.txt,其结构如下所示:

2019001.25.23.45.30.08.32.24.0
2019002.23.13.19.45.16.12.02.0
2019003.40.04.23.05.12.30.10.1
2019004.07.31.28.06.36.43.23.1
2019005.35.48.18.37.49.10.27.1

数据的第一列表示 开奖期数,最后一列表示 下一期 特别码的单双号,0为双,1为单。中间的 7 列表示开奖号码,其中前 6 个为普通码,最后 1 个为特别码。例如:2019002 期 普通码为 23.13.19.45.16.12,特别码为 02,特别码为 双数,则 2019001 期 最后 1 个数为 0。又如:2019004 期 普通码为 07.31.28.06.36.43,特别码为 23,特别码为单数,则 2019003 期 最后一个数为 1。如此递推,本期最后一列的 label 值,为下一期特别码的单双值(双为 0,单为 1)。

2. 第二步,对训练模型进行调参,训练,和模型保存。

模型训练代码 ball49_xgboost_train.py 如下:

#!/usr/bin/env python
# _*_ coding:utf-8 _*_
# ============================================
# @Time     : 2020/01/15 22:45
# @Author   : WanDaoYi
# @FileName : ball49_xgboost_train.py
# ============================================import os
import numpy as np
from sklearn import metrics
from sklearn.model_selection import GridSearchCV
from xgboost.sklearn import XGBClassifier
import joblib
from config import cfgimport matplotlib.pyplot as plt
import matplotlib# 用于解决画图中文乱码
font = {"family": "SimHei"}
matplotlib.rc("font", **font)class Ball49Train(object):def __init__(self):self.train_data_path = cfg.TRAIN.TRAIN_DATA_INFO_PATHself.val_data_path = cfg.TRAIN.VAL_DATA_INFO_PATHself.model_save_path = cfg.TRAIN.MODEL_SAVE_PATHself.x_train, self.y_train = self.read_data(self.train_data_path)self.x_val, self.y_val = self.read_data(self.val_data_path)self.roc_flag = cfg.TRAIN.ROC_FLAGpass# 读取数据def read_data(self, data_path):with open(data_path, "r") as file:data_info = file.readlines()data_list = [data.strip().split(".") for data in data_info]data_list = np.array(data_list, dtype=int)data_info = data_list[:, 1:-1]label_info = data_list[:, -1:]print(data_info[: 5])# ravel() 是列转行,用于解决数据转换警告。return data_info, label_info.ravel()passdef best_estimators_depth(self):# np.arange 可以生成 float 类型,range 只能生成 int 类型best_param = {'n_estimators': range(10, 201, 5),'max_depth': range(1, 20, 1)}best_gsearch = GridSearchCV(estimator=XGBClassifier(learning_rate=0.1,gamma=0,subsample=0.8,colsample_bytree=0.8,objective='binary:logistic',nthread=4,min_child_weight=5,seed=27),param_grid=best_param, scoring='roc_auc', iid=False, cv=10)best_gsearch.fit(self.x_train, self.y_train)print("best_param:{0}".format(best_gsearch.best_params_))print("best_score:{0}".format(best_gsearch.best_score_))# best_param: {'max_depth': 3, 'n_estimators': 20}# best_score: 0.5851190476190475return best_gsearch.best_params_passdef best_lr_gamma(self):# np.arange 可以生成 float 类型,range 只能生成 int 类型best_param = {'learning_rate': np.arange(0.1, 1.1, 0.1),'gamma': np.arange(0.1, 5.1, 0.2)}best_gsearch = GridSearchCV(estimator=XGBClassifier(n_estimators=20,max_depth=3,# learning_rate=0.1,# gamma=0,subsample=0.8,colsample_bytree=0.8,objective='binary:logistic',nthread=4,min_child_weight=5,seed=27),param_grid=best_param, scoring='roc_auc', iid=False, cv=10)best_gsearch.fit(self.x_train, self.y_train)print("best_param:{0}".format(best_gsearch.best_params_))print("best_score:{0}".format(best_gsearch.best_score_))# best_param: {'gamma': 1.7000000000000004, 'learning_rate': 0.5}# best_score: 0.6467261904761905return best_gsearch.best_params_passdef best_subsmaple_bytree(self):# np.arange 可以生成 float 类型,range 只能生成 int 类型# 调整subsample(行),colsample_bytree(列)best_param = {'subsample': np.arange(0.1, 1.1, 0.1),'colsample_bytree': np.arange(0.1, 1.1, 0.1)}best_gsearch = GridSearchCV(estimator=XGBClassifier(n_estimators=20,max_depth=3,learning_rate=0.5,gamma=1.7,# subsample=1.0,# colsample_bytree=0.8,objective='binary:logistic',nthread=4,min_child_weight=5,seed=27),param_grid=best_param, scoring='roc_auc', iid=False, cv=10)best_gsearch.fit(self.x_train, self.y_train)print("best_param:{0}".format(best_gsearch.best_params_))print("best_score:{0}".format(best_gsearch.best_score_))# best_param: {'colsample_bytree': 0.8, 'subsample': 0.8}# best_score: 0.6467261904761905return best_gsearch.best_params_passdef best_nthread_weight(self):# np.arange 可以生成 float 类型,range 只能生成 int 类型best_param = {'nthread': range(1, 20, 1),'min_child_weight': range(1, 20, 1)}best_gsearch = GridSearchCV(estimator=XGBClassifier(n_estimators=20,max_depth=3,learning_rate=0.5,gamma=1.7,subsample=0.8,colsample_bytree=0.8,objective='binary:logistic',# nthread=4,# min_child_weight=5,seed=27),param_grid=best_param, scoring='roc_auc', iid=False, cv=10)best_gsearch.fit(self.x_train, self.y_train)print("best_param:{0}".format(best_gsearch.best_params_))print("best_score:{0}".format(best_gsearch.best_score_))# best_param: {'min_child_weight': 5, 'nthread': 1}# best_score: 0.6467261904761905return best_gsearch.best_params_passdef best_seek(self):# np.arange 可以生成 float 类型,range 只能生成 int 类型best_param = {'seed': range(1, 1000, 1)}best_gsearch = GridSearchCV(estimator=XGBClassifier(n_estimators=20,max_depth=3,learning_rate=0.5,gamma=1.7,subsample=0.8,colsample_bytree=0.8,nthread=1,min_child_weight=5,# seed=27,objective='binary:logistic'),param_grid=best_param, scoring='roc_auc', iid=False, cv=10)best_gsearch.fit(self.x_train, self.y_train)print("best_param:{0}".format(best_gsearch.best_params_))print("best_score:{0}".format(best_gsearch.best_score_))# best_param: {'seed': 27}# best_score: 0.6467261904761905return best_gsearch.best_params_pass# 绘制 ROC 曲线def plt_roc(self, model):if self.roc_flag:y_proba = model.predict_proba(self.x_val)# 预测为 0 的概率y_zero = y_proba[:, 0]# 预测为 1 的概率y_one = y_proba[:, 1]print("AUC Score2: {}".format(metrics.roc_auc_score(self.y_val, y_one)))# 得到误判率、命中率、门限fpr, tpr, thresholds = metrics.roc_curve(self.y_val, y_one)# 计算aucroc_auc = metrics.auc(fpr, tpr)# 对ROC曲线图正常显示做的参数设定# 用来正常显示中文标签, 上面设置过# plt.rcParams['font.sans-serif'] = ['SimHei']# 用来正常显示负号plt.rcParams['axes.unicode_minus'] = Falseplt.plot(fpr, tpr, label='{0}_AUC = {1:.5f}'.format("xgboost", roc_auc))plt.title('ROC曲线')plt.xlim([-0.05, 1.05])plt.ylim([-0.05, 1.05])plt.legend(loc='lower right')plt.plot([0, 1], [0, 1], 'r--')plt.ylabel('命中率: TPR')plt.xlabel('误判率: FPR')plt.show()pass# 较好的 模型参数 进行训练def best_param_xgboost(self):best_model = XGBClassifier(n_estimators=20,max_depth=3,learning_rate=0.5,gamma=1.7,subsample=0.8,colsample_bytree=0.8,nthread=1,min_child_weight=5,seed=27,objective='binary:logistic')best_model.fit(self.x_train, self.y_train)y_pred = best_model.predict(self.x_val)acc_score = metrics.accuracy_score(self.y_val, y_pred)print("acc_score: {}".format(acc_score))print("score: {}".format(best_model.score(self.x_val, self.y_val)))save_path = self.model_save_path + "acc={:.6f}".format(acc_score) + ".m"# 判断模型是否存在,存在则删除if os.path.exists(save_path):os.remove(save_path)pass# 保存模型joblib.dump(best_model, save_path)print("AUC Score: {}".format(metrics.roc_auc_score(self.y_val, y_pred)))# 绘制 ROC 曲线self.plt_roc(best_model)passif __name__ == "__main__":demo = Ball49Train()# demo.best_estimators_depth()# demo.best_lr_gamma()# demo.best_subsmaple_bytree()# demo.best_nthread_weight()# demo.best_seek()demo.best_param_xgboost()pass

后面注释的这几个方法,

# demo.best_estimators_depth()
    # demo.best_lr_gamma()
    # demo.best_subsmaple_bytree()
    # demo.best_nthread_weight()
    # demo.best_seek()

为使用网格搜索法进行调参,从上到下,逐个放开来运行(每次只运行一个方法),得到较好的参数,代到下一个方法中去,继续调参。等把这里的参数,都调得一个较好的值之后,带入 demo.best_param_xgboost() 方法中,进行训练,并保存模型。

当然,也可以直接接收上面方法的返回值,然后,作为下一个方法的入参,一次跑完。

分开来跑的好处是,可以在每一步都观察到效果值,选择效果值好的,再进跑下一个方法,总体结果会更好点。

如果每次跑的结果,都是一样的话,那,就再重新 prepare.py 一下,因为,样本集不同,对训练也是有影响的。

运行上面的代码,得到如下所示结果:

发现,由于训练数据过少,得到的模型 的准确率并不高,只有 20% 命中而已(acc_score: 0.2)。

在实际运用中,对于二分类模型来说,并不是 准确率越低越不好。因为,在二分类问题中,只有准确率为 50% 是最不好的,其他偏离准确率为 50% 的模型越大,则模型越好。就好比,我们猜硬币,猜中正面的准确率为 20%,那么,就意味着,猜中反面的概率为 80%。这时候,我们把模型,反着来用,则得到一个较好的模型。就好像上面,20%命中的模型,反过来就是 80% 命中的模型,就目前的数据来说,还是可以的。买个彩票单双,能中个 八成。

3. 第三步,对验证集数据进行验证。

模型预测代码 ball49_xgboost_test.py 如下:

#!/usr/bin/env python
# _*_ coding:utf-8 _*_
# ============================================
# @Time     : 2020/01/15 22:45
# @Author   : WanDaoYi
# @FileName : ball49_xgboost_test.py
# ============================================import numpy as np
import joblib
from config import cfgclass Ball49Test(object):def __init__(self):self.test_data_path = cfg.TEST.TEST_DATA_INFO_PATHself.data_time, self.data_info, self.true_label = self.read_data(self.test_data_path)self.acc_flag = cfg.TEST.ACC_FLAGself.model = joblib.load(cfg.TEST.MODEL_PATH)pass# 读取数据def read_data(self, data_path):with open(data_path, "r") as file:data_info = file.readlines()data_list = [data.strip().split(".") for data in data_info]data_list = np.array(data_list, dtype=int)print("predict_data_shape: {}".format(data_list.shape))# 期数data_time = data_list[:, : 1].T[0]# 开奖号码data_info = data_list[:, 1:]# 真实的特别码true_label = data_list[:, -1:].T[0]return data_time, data_info, true_labelpassdef predict_data(self):y_pred = self.model.predict(self.data_info)print("y_pred: {}".format(y_pred))pred_label = []# 如果是 acc > 0.5, 则直接使用预测值if self.acc_flag:for pred in y_pred:if pred == 1:pred_label.append("单")else:pred_label.append("双")passpass# 如果 acc < 0.5,则将预测值反过来用。因为,刚好预测反了。else:for pred in y_pred:if pred == 1:pred_label.append("双")else:pred_label.append("单")passpasstrue_label = []true_label_len = len(self.true_label)print("true_label: {}".format(self.true_label[1:]))for index in range(1, true_label_len):label = self.true_label[index]label_str = str(label)if len(label_str) == 1:label_str = "0" + label_strif label % 2 == 1:true_label.append(label_str + "-单")passelse:true_label.append(label_str + "-双")passtrue_label.append("坐等开奖收钱!...... ")time_number = self.data_time[1:]time_number = time_number.tolist()time_number.append(self.data_time[-1] + 1)for i in range(0, true_label_len):print("第 {} 期 预测结果为 ({}), 真实结果为 ({})".format(time_number[i],pred_label[i],true_label[i]))passif __name__ == "__main__":demo = Ball49Test()demo.predict_data()pass

测试数据 (将最近的一期放在最后一行,用来预测下期开奖号码。如下面的数据,我们关心的是 2020005 期开什么):

2019140  35.31.22.24.39.28+12
2019141 11.24.26.22.41.42+02
2019142 23.11.30.14.32.05+20
2019143 27.25.36.31.22.01+47
2019144 28.10.18.20.05.33+17
2020001 37.29.15.34.30.17+40
2020002 27.24.37.25.30.01+08
2020003 17.42.37.35.49.05+39
2020004 02.30.24.46.20.32+01

使用 acc=0.2 的模型(将预测结果反过来看,就是 acc=0.8 的模型)运行,结果如下:

acc=0.2 的模型下载地址:链接:https://pan.baidu.com/s/1-BsvB1ksBHtOQaVZUTL17Q 
                                                  提取码:djxm

第 2019141 期 预测结果为 (双), 真实结果为 (02-双)
第 2019142 期 预测结果为 (双), 真实结果为 (20-双)
第 2019143 期 预测结果为 (单), 真实结果为 (47-单)
第 2019144 期 预测结果为 (单), 真实结果为 (17-单)
第 2020001 期 预测结果为 (双), 真实结果为 (40-双)
第 2020002 期 预测结果为 (单), 真实结果为 (08-双)
第 2020003 期 预测结果为 (单), 真实结果为 (39-单)
第 2020004 期 预测结果为 (单), 真实结果为 (01-单)
第 2020005 期 预测结果为 (双), 真实结果为 (坐等开奖收钱!...... )

从上面的结果来看,数据量虽然少了点,但是,命中还是可以的。

这,只是一个以 49 个号码的彩票为例。道理是相通的,如果你想做其他彩票模型,也是很容易的,按照这个思路,一套带走即可。特别是,做那些 什么 5 分彩(5 分钟开一次的那种),或 快乐 10 分。这些彩票,数据量就会大上许多,准确率,也会高很多的。当然,如果不满意这个准确率,可以考虑用深度学习的全连接来做。写完这篇集成学习,我也开始进入深度学习。有机会,到时候,再考虑一下,怎么用深度学习全连接做呗。

再次重申,本章节只做学术研究,买彩票的人,亏盈自负,本人概不负责。谢谢理解。

返回主目录

返回集成学习目录

上一章:机器篇——集成学习(七) 细说 XGBoost 算法

下一章:机器篇——集成学习(九) 细说 hotel_pred 项目(酒店预测)

机器篇——集成学习(八) 细说 ball49_pred 项目(彩票预测)相关推荐

  1. 机器篇——集成学习(九) 细说 hotel_pred 项目(酒店预测)

    返回主目录 返回集成学习目录 上一章:机器篇--集成学习(八) 细说 ball49_pred 项目(彩票预测) 本小节,细说 hotel_pred 项目(酒店预测) 三. 项目解说 9. hotel_ ...

  2. (十五)集成学习(下)——蒸汽量预测

    参考:DataWhale教程链接 集成学习(上)所有Task: (一)集成学习上--机器学习三大任务 (二)集成学习上--回归模型 (三)集成学习上--偏差与方差 (四)集成学习上--回归模型评估与超 ...

  3. DataWhale集成学习Task15 集成学习案例二 (蒸汽量预测)

    集成学习案例二 (蒸汽量预测) 文章目录 集成学习案例二 (蒸汽量预测) 1 整体思路 1.1 整体步骤 1.2 评价指标 2 实战演练 导入package 加载数据 探索数据分布 特征工程 模型构建 ...

  4. 学习java的第一个实践练手项目---彩票预测系统

    这个项目用了8个晚上(20:00-23:00)加上2个白天完成. 一.所用知识点 1.java语言基础 2.多线程 3.Swing控件 4.数据库技术(MySQL) 二.项目目标 1.通过登录窗体点击 ...

  5. (三)集成学习上——偏差与方差

    参考:DataWhale教程链接 集成学习(上)所有Task: (一)集成学习上--机器学习三大任务 (二)集成学习上--回归模型 (三)集成学习上--偏差与方差 (四)集成学习上--回归模型评估与超 ...

  6. 关于集成学习的总结(一) 投票法

    最近在写那个完整的机器学习项目博客时候,我本来打算用一篇博客来写的.结果发现要写的越来越多.而且最关键的是,以前以为有些地方理解了,其实并没有理解.大概这就是写博客记笔记的好处吧...可惜我上高中初中 ...

  7. 集成学习—SGBT随机梯度提升树

    上一篇集成学习-GBDT原理理解中提到,由于GBDT的弱学习器之间存在依赖关系,难以并行训练数据,因此若数据量较大时程序运行太慢.这里可以通过加入了自采样的SGBT来达到部分并行,这是一个能改善GBD ...

  8. 集成学习—Adaboost(理解与应用)

    在上一篇集成学习-Adaboost(论文研读)中已经将Adaboost的原始论文精读了一遍,这篇博客主要是对Adaboost算法(主要是二分类的Adaboost)进行更深入的理解和推导,以及尝试下关于 ...

  9. 集成学习—Adaboost(论文研读)

    这篇博客主要是对Adaboost算法的论文精度,包括翻译以及自己的一些基本理解,如果对原论文不感兴趣,只是想快速理解与应用,可以参考另外一篇集成学习-Adaboost(理解与应用) Adaboost是 ...

最新文章

  1. android中static方法,StaticLayout如何在Android中使用?
  2. python怎么安装requests库-Python3.6安装及引入Requests库
  3. spark环境搭建(idea版本)
  4. 《YOLO算法笔记》(草稿)
  5. 域中计算机与用户,域内计算机和用户获取实现vbs代码
  6. 开发人员必备的 Chrome 扩展
  7. 【渝粤题库】陕西师范大学163212旅游地理学 作业(专升本)
  8. 判断app访问还是web访问网站
  9. 【NOIP2016提高A组模拟9.14】数列编辑器
  10. matlab有限单元法计算桁架算例代码
  11. 十分钟掌握Google Guice(上)
  12. opencms 发布过程深入研究
  13. 实现自定义Sql 注入器
  14. 图音80系列车载导航/DVD分体机安装DSA
  15. 关于毕业:三方协议、派遣证、干部身份等常识
  16. 阿里中台搞了3年,凉了?网传:副总裁玄难“背锅”,辞职创业!咸鱼放弃维护 Flutter!...
  17. android 支付宝登录无法返回
  18. mysql小计_Mysql必读用SQL实现统计报表中的小计与合计的方法详解
  19. 看图吧地图数据如何玩转企业地信圈
  20. CSS面试题整理汇总

热门文章

  1. MATLAB中tf函数的使用
  2. 图片上传预览,解决路径为fakepath
  3. 分析大众点评产品成败的因素,从pest和波特五力出发。
  4. python module ‘mitmproxy.proxy‘ has no attribute ‘config‘问题解决
  5. Java2实用教程2(第五版)耿祥义课后习题参考答案
  6. 银行家算法 C语言实现
  7. 15年美亚杯电子取证
  8. Kindeditor编辑器使用
  9. Android之关于ComponentName的参数实测
  10. Python3 range()函数的替代品——xrange()的作用