GBDT模型 0基础小白也能懂(附代码)

news/2024/9/26 5:23:47

GBDT模型 0基础小白也能懂(附代码)

原文链接

啥是GBDT

GBDT(Gradient Boosting Decision Tree),全名叫梯度提升决策树,是一种迭代的决策树算法,又叫 MART(Multiple Additive Regression Tree),它通过构造一组弱的学习器(树),并把多颗决策树的结果累加起来作为最终的预测输出。该算法将决策树与集成思想进行了有效的结合。

Gradient Boosting里的boosting是啥?

Boosting方法训练基分类器时采用串行的方式,各个基分类器之间有依赖。它的基本思路是将基分类器层层叠加,每一层在训练的时候,对前一层基分类器分错的样本,给予更高的权重(并根据前一个基分类器的表现计算误差。对于分类正确的样本,它们的权重保持不变或减少;对于分类错误的样本,算法会增加它们的权重)。测试时,根据各层分类器的结果的加权得到最终结果。

Bagging 与 Boosting 的串行训练方式不同,Bagging 方法在训练过程中,各基分类器之间无强依赖,可以进行并行训练。

GBDT详解

所有弱分类器的结果相加等于预测值。
每次都以当前预测为基准,下一个弱分类器去拟合误差函数对预测值的残差(预测值与真实值之间的误差)。
GBDT的弱分类器使用的是树模型。

实际工程实现里,GBDT 是计算负梯度,用负梯度近似残差,而不是像这样简单相减

1)GBDT与负梯度近似残差

回归任务下,GBDT在每一轮的迭代时对每个样本都会有一个预测值,此时的损失函数为均方差损失函数:

可以看出,当损失函数选用「均方误差损失」时,每一次拟合的值就是(真实值-预测值),即残差。

2)GBDT训练过程

我们来借助1个简单的例子理解一下 GBDT 的训练过程。假定训练集只有4个人\(A,B,C,D\),他们的年龄分别是\((14,16,24,26)\)。身份分别是高一学生,高三学生,应届毕业生,已工作两年。为了按照特征预测年龄

先用回归树训练后看结果,里面分的节点只是例子,没啥具体含义,先按购物金额分一下,再按上网时间或者上网时段分一下

接下来改用 GBDT 来训练。由于样本数据少,我们限定叶子节点最多为2即每棵树都只有一个分枝),并且限定树的棵树为2。先按照购物金额来分出一棵树:

上图中的树很好理解:\(A,B\) 年龄较为相近,\(C,D\) 年龄较为相近,被分为左右两支,每支用平均年龄作为预测值。

  • 我们计算残差(即「实际值」-「预测值」),所以\(A\)的残差 \(15-1=14\)
  • 这里 \(A\) 的「预测值」是指前面所有树预测结果累加的和,在当前情形下前序只有一棵树,所以直接是 \(15\),其他多树的复杂场景下需要累加计算作为 \(A\) 的预测值。

那么到这里预测完成,接下来就是要用一个弱分类器(树)去拟合误差函数对预测值的残差(预测值与真实值之间的误差)

上图中的树就是残差学习的过程了,里面提问和回答同样也只是例子:

  • \(A,B,C,D\) 的值换作残差 \(-1,1,-1,1\),再构建一棵树学习,这棵树只有两个值 \(1\)\(-1\),直接分成两个节点:\(A,C\) 在左边,\(B,D\) 在右边。
  • 这棵树学习残差,在我们当前这个简单的场景下,已经能保证预测值和实际值(上一轮残差)相等了。
  • 我们把这棵树的预测值累加到第一棵树上的预测结果上,就能得到真实年龄,这个简单例子中每个人都完美匹配,得到了真实的预测值。

最终的预测过程是这样的:

  • \(A\):高一学生,购物较少,经常问学长问题,真实年龄 14 岁,预测年龄 15-1
  • \(B\):高三学生,购物较少,经常被学弟提问,真实年龄 16 岁,预测年龄 15+1
  • \(C\):应届毕业生,购物较多,经常问学长问题,真实年龄 24 岁,预测年龄 25-1
  • \(D\):工作两年员工,购物较多,经常被学弟提问,真实年龄 26 岁,预测年龄 25+1

综上,GBDT 需要将多棵树的得分累加得到最终的预测得分,且每轮迭代,都是在现有树的基础上,增加一棵新的树去拟合前面树的预测值与真实值之间的残差。

梯度提升 vs 梯度下降

下面我们来对比一下「梯度提升」与「梯度下降」。这两种迭代优化算法,都是在每1轮迭代中,利用损失函数负梯度方向的信息,更新当前模型,只不过:

梯度提升(比如前面的GBDT):通过构建多个弱学习器,在函数空间中逼近最优解,不需要模型参数化,利用损失函数的负梯度逐步优化模型。

梯度下降(比如线性回归,逻辑回归):通过参数化的模型,利用损失函数的梯度更新参数,最终找到使损失最小的参数。

优缺点

随机森林 vs GBDT

何时使用哪个模型?

使用随机森林:如果你需要快速构建一个可用的模型,并且数据量较大,可以考虑使用随机森林。它对参数的敏感性较低,容易调参,适用于初步探索。
使用GBDT:如果你在追求更高的模型精度,尤其是在较复杂的数据集上,GBDT 通常表现更好,但需要更多的时间来调整超参数。

代码实现

还是用加州房价数据集

# 导入必要的库
import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import fetch_california_housing
from sklearn.model_selection import train_test_split
from sklearn.ensemble import GradientBoostingRegressor
from sklearn.metrics import mean_squared_error, r2_score# 1. 加载加州房价数据集
data = fetch_california_housing()
X = data.data  # 特征矩阵
y = data.target  # 目标变量(房价)# 2. 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)# 3. 创建GBDT回归模型
# n_estimators: 基分类器的数量
# learning_rate: 每个基分类器的学习率
# max_depth: 决策树的最大深度
gbdt_regressor = GradientBoostingRegressor(n_estimators=100, learning_rate=0.1, max_depth=3, random_state=42)# 4. 训练模型
gbdt_regressor.fit(X_train, y_train)# 5. 进行预测
y_pred_train = gbdt_regressor.predict(X_train)
y_pred_test = gbdt_regressor.predict(X_test)# 6. 评估模型# 计算测试集的均方误差
mse_test = mean_squared_error(y_test, y_pred_test)
print(f"Mean Squared Error (Test): {mse_test:.2f}")# 计算训练集的均方误差
mse_train = mean_squared_error(y_train, y_pred_train)
print(f"Mean Squared Error (Train): {mse_train:.2f}")# 计算R²得分
r2_test = r2_score(y_test, y_pred_test)
r2_train = r2_score(y_train, y_pred_train)
print(f"R² Score (Test): {r2_test:.2f}")
print(f"R² Score (Train): {r2_train:.2f}")# 7. 可视化GBDT的预测结果(实际值 vs. 预测值)
plt.scatter(y_test, y_pred_test, alpha=0.5)
plt.xlabel('Actual Prices')
plt.ylabel('Predicted Prices')
plt.title('Actual vs Predicted Prices (GBDT)')
plt.show()

结果如下

Mean Squared Error (Test): 0.29
Mean Squared Error (Train): 0.26
R² Score (Test): 0.78
R² Score (Train): 0.81

和我们之前的决策树相比明显好了很多,毕竟这里数据规模不大。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.ryyt.cn/news/56558.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈,一经查实,立即删除!

相关文章

日程

每日日程 2024.8.6周二 / 二零二四年(甲辰年)七月初三蓝色字体为跳转链接(按下Ctrl+点击蓝色字体链接即可完成跳转) 个人博客地址:GitHub:Madman-luxun (madman) // 博客园:XG_madman c++黑马程序员笔记:C++基础编程 | welcome to here (namic00.github.io) 野火开发…

二项式定理(二项式展开)

目录引入正题延伸 引入 首先有一个广为人知的结论: \[(a+b)^2=a^2+2ab+b^2 \]那么,如何求 \((a+b)^3\) 呢?手算,如下: \[\begin{aligned} (a+b)^3 &= (a+b)\times(a+b)^2\\ &=(a+b)\times(a^2+2ab+b^2)\\ &=[a\times(a^2+2ab+b^2)]+[b\times(a^2+2ab+b^2)]\\ …

机器视觉检测的速度六大影响因素

物料处理时间 材料处理时间是指待检测材料暴露在图像采集介质前面,以便能够充分聚焦在材料上以获取图像的时间。在工业环境中,材料通常位于装配线或传送带上。相机是固定的或可移动的,放置在装配线的某个点。当材料进入相机的焦点区域时,材料处理时间开始,当材料完全聚焦时…

2024-09-04:用go语言,给定一个长度为n的数组 happiness,表示每个孩子的幸福值,以及一个正整数k,我们需要从这n个孩子中选出k个孩子。 在筛选过程中,每轮选择一个孩子时,所有尚未选

2024-09-04:用go语言,给定一个长度为n的数组 happiness,表示每个孩子的幸福值,以及一个正整数k,我们需要从这n个孩子中选出k个孩子。 在筛选过程中,每轮选择一个孩子时,所有尚未选中的孩子的幸福值都会减少 1。需要注意的是,幸福值不能降低到负数,只有在其为正数时才能…

初探编译链接原理

这篇博文由一个 bug 引出了编译链接的整个过程。我们可以看到一个源代码文件最终变成一个可执行文件中间经历了编译和链接两个过程,编译过程又分为预编译,编译,和汇编;预编译阶段主要处理#开头的代码,编译则是进行一些语法分析和优化,最终生成汇编代码,而汇编则是生成机…

canvas版本的五子棋

代码:<!Doctype html> <html lang="zh_cn"><head><meta http-equiv="Content-Type" content="text/html; charset=utf-8" /><title>五子棋</title><meta name="Keywords" content="&quo…

vue router路径重复时报错

参考——  https://blog.csdn.net/zz00008888/article/details/119566375 报错: Avoided redundant navigation to current location: "/Eee". NavigationDuplicated: Avoided redundant navigation to current location: "/Eee".在router的index下添加…