5 CatBoost模型

目录

1 背景

2 原理

2.1 类别特征处理

2.1.1 传统目标编码: TS

2.1.2 Greedy TS

2.1.3 ordered TS编码

2.1.4 CatBoost处理Categorical features总结

2.2.预测偏移处理

2.2.1 梯度无偏估计

2.3 树的构建​​​​​​​

3 优缺点

优点

4 代码


1 背景

        终于到了CatBoost,这个模型我在打比赛的时候相对于lightGBM用的少一些,但是我一般都会进行尝试,尤其是类别型特征很多的时候。曾依赖这个模型单模干到了top8,那时候还不懂的继承的优雅。我们开始这个模型介绍吧。

        CatBoost是一种基于对称决策树(oblivious trees)为基学习器实现的参数较少、支持类别型变量和高准确性的GBDT框架,主要解决的痛点是高效合理地处理类别型特征,这一点从它的名字中可以看出来,CatBoost是由Categorical和Boosting组成。此外,CatBoost还解决了梯度偏差(Gradient Bias)以及预测偏移(Prediction shift)的问题,从而减少过拟合的发生,进而提高算法的准确性和泛化能力。

        上面做了一个概述,那么黑体的名字如何理解呢?

        另外,与其提升算法不同,CatBoost使用对称全二叉树(这种树的特点是每一层使用相同的分割特征)。这样一来,树是更简单的结构,我们也就避免了过度拟合的危险。此外,由于我们的基础模型结构简单,我们有更快的预测器。

                           

2 原理

2.1 类别特征处理

        CatBoost算法的设计初衷是为了更好的处理GBDT特征中的categorical features(比如性别【男,女】)。在处理 GBDT特征中的categorical features的时候,最简单的方法是用 categorical feature 对应的标签的平均值来替换(target encoding, 这个在比赛中我也是常用,但是存在问题)。在决策树中,标签平均值将作为节点分裂的标准。这种方法被称为 Greedy Target-based Statistics , 简称 Greedy TS;

2.1.1 传统目标编码: TS

用公式来表达就是:

x_{i,k} = \frac{ \sum_{j=1}^{n} [x_{j,k}=x_{i,k}] Y_j}{\sum_{j=1}^{n} [x_{j,k}=x_{i,k}]},  =>  groupby('cat')[label].mean()

        TS编码有一个缺点,极端情况下当训练集中某类取值只有一个样本、或者没有样本时,计算的编码值就失真了,也就是容易受噪声数据影响。

        如果强行用标签的平均值来表示特征的话,当训练数据集和测试数据集数据结构和分布不一样的时候会出条件偏移问题。

2.1.2 Greedy TS

        一个标准的改进 TS的方式是添加先验分布项,这样可以减少噪声和低频率类别型数据对于数据分布的影响

直接上公式:

x_{i,k} = \frac{ \sum_{j=1}^{n} [x_{j,k}=x_{i,k}] Y_j + ap}{\sum_{j=1}^{n} [x_{j,k}=x_{i,k}]+a}

        其中p是添加的先验项,a通常是大于0的权重系数。添加先验项是一个普遍做法,针对类别数较少的特征,它可以减少噪声数据。对于回归问题,一般情况下,先验项可取数据集label的均值。对于二分类,先验项是正例的先验概率。

        Greedy TS编码也存在一个问题,即目标泄露。也需要训练预测集合数据分布一致;

2.1.3 ordered TS编码

        它是catboost的主要思想,依赖于排序,受online learning algorithms的启发得到,对于某一个样本,TS的值依赖于观测历史,为了在离线的数据上应用该思想,我们将数据随机排序,对于每一个样本,利用该样本之前数据计算该样本类别值的TS值。如果仅仅使用一个随机序列,那么计算得到值会有较大的方差,因此我们使用不同的随机序列来计算。

        在某种排序状态 𝜎 下,样本 x_i 在分类特征 𝑘 下的值为 x_{i}^{k}x_{i}^{k}的ordered TS编码值是基于排在其前面的样本D_{\sigma }计算的,在D_{\sigma }中计算分类特征 𝑘 下取值与x_{i}^{k}相同的样本的Greedy TS编码值,该值即为的ordered TS编码值。举例说明

在上图中,经过样本排序后,样本的排序情况为{4,3,7,2,6,1,5},计算样本4的ordered TS编码值时,由于没有样本排在其前面,因此其ordered TS编码值计算方式为

T(y=1|x=D) = \frac{0+ap}{0+a} = p

计算样本6的ordered TS编码值时,排在其前面的样本为{4,3,7,2},在这4个样本中,特征取值为D的只有样本4,因此其ordered TS编码值计算方式为:

T(y=1|x=D) = \frac{0+ap}{0+a} = p

2.1.4 CatBoost处理Categorical features总结

  • 首先会计算一些数据的statistics。计算某个category出现的频率,加上超参数,生成新的numerical features。这一策略要求同一标签数据不能排列在一起(即先全是之后全是这种方式),训练之前需要打乱数据集。
  • 第二,使用数据的不同排列。在每一轮建立树之前,先扔一轮骰子,决定使用哪个排列来生成树。
  • 第三,考虑使用categorical features的不同组合。例如颜色和种类组合起来,可以构成类似于blue dog这样的特征。当需要组合的categorical features变多时,CatBoost只考虑一部分combinations。在选择第一个节点时,只考虑选择一个特征,例如A。在生成第二个节点时,考虑A和任意一个categorical feature的组合,选择其中最好的。就这样使用贪心算法生成combinations。
  • 第四,除非向gender这种维数很小的情况,不建议自己生成One-hot编码向量,最好交给算法来处理。

2.2.预测偏移处理

        在GBDT算法中,每一棵树都是为了拟合前一棵树上的梯度,构造树时所有的样本都参与了,一个样本参与了建树,然后又用这棵树去估计样本值,这样的估计就不是无偏估计,当测试集和训练集上的样本分布不一致时,模型就会因过拟合而性能不佳,即在测试集上产生了预测偏移。

2.2.1 梯度无偏估计

        对于样本 𝑥𝑖 ,如果用一个不包含它的模型去估计它的梯度,估计结果可以视为无偏估计。基于这种思路,CatBoost算法中采用了如下策略:在每一轮迭代时,将样本集排序,然后训练 𝑛 个模型 𝑀𝑖,𝑖=1,...,𝑛 ,𝑛为样本数量,其中 𝑀𝑖 是由前 𝑖 个样本训练得到(基于本轮样本的排序,包含样本𝑖),然后估计样本𝑖的梯度时,使用模型𝑀𝑖−1 来估计,因为𝑀𝑖−1是由不包含样本𝑖的样本训练得到,因此该估计结果是无偏估计。

                        

        这种方式尽管得到了无偏估计,但是对于排序靠前的样本,它的梯度估计结果可能并不太准确,具有较大的方差,因为估计它的模型是由较少的样本训练的,而且是基于本轮迭代中的样本排序计算的,为了减少预测方差,CatBoost在每轮迭代中都会对样本进行重新排序,然后按照相同的思路估计本轮中样本的梯度,这样多轮迭代的最终结果就可以获取一个较小的方差。

        当然,如果每轮迭代都要训练𝑛 个模型,那是一个比较大的工作量,CatBoost算法对这个过程做了些简化来提升训练速度,详情后面分析树的创建过程时再说明。

2.3 树的构建

        得到树结构,也是节点分裂的过程。在GBDT、XGBoost、LightGBM等算法中,节点分类时需要遍历所有候选特征及分裂阈值,CatBoost算法也采用了这种策略,但有两点不同:

  • 分类型特征处理;
  • 数值型特征的空值处理。在CatBoost算法中,将空值全部转换为最小值(默认),或者最大值;

        每一轮迭代都会创建 𝑛 个模型,也就是创建𝑛棵树(在CatBoost算法中,并不是𝑛棵树,而是 [𝑙𝑜𝑔2𝑛] ,这个只是为了减少计算量,因此在阐述CatBoost算法的原理时仍然使用𝑛),训练树是比较耗时的过程,在整个模型训练的时间中占很大的比重,如果每一轮都要重新训练𝑛棵树,那将会非常耗时,CatBoost算法在这点上做了调整,具体做是:

1 在第一轮迭代中,在选定样本排序状态 𝜎 下,分别用前 𝑗 个样本训练模型 𝑀𝜎,𝑗 ,也就是得到 𝑛 棵树。一个树的训练过程包含两部分:第一步,得到树结构;第二步,计算叶子节点的值。第一步中得到的树结构,也即是每一层选用什么分裂特征,分裂阈值是多少

2 得到模型𝑀𝜎,𝑗的树结构后,CatBoost会使将该树结构复用到后续的所有迭代过程中。例如在第一轮迭代中,由前100个样本训练得到了模型 𝑀𝜎,100 的树结构,后续的迭代过程中会直接使用该树结构,然后将对应排序状态 𝜎′ 下的前100个样本直接应用该树结构,将样本划分到对应的叶子节点上,得到完整的模型,而不用再重新遍历样本的特征来寻找最佳划分特征和划分阈值。

        ​​​​​​​        ​​​​​​​        

那么在每一层中,如何评判用哪种特征分裂最好呢?CatBoost算法采用这样的策略:

  1. 基本本轮样本排序状态,得到每一棵树上的样本,首先依据前面迭代轮次的结果计算每个样本的梯度,也就是得到根节点中每个样本的梯度向量 𝐺 。
  2. 遍历候选特征来分裂根节点,得到多种分裂结果,分别计算每个样本的叶子值增量 Δ(𝑖) ,其中 𝑖 指第 𝑖 个样本(这个增量值的含义原论文中也没有明确说明,只是说是一个增量值,这里也参照原文进行说明),Δ(𝑖)的计算方式为:在样本 𝑖 所属的叶子节点上,计算排在样本 𝑖 前面的样本的梯度的平均值,该平均值即为Δ(𝑖)。(这里要说明一下,上面说的叶子节点,是相对于上一层的节点来说的,并不是整棵树的叶子节点),这样就得到每个样本的增量值Δ(𝑖),这些增量值可以组成一个向量 Δ;

  3. 采用余弦相似度方法,计算各种候选特特征下节点分裂的损失 𝑙𝑜𝑠𝑠(𝐺,Δ) ,选择损失最小的方式来分裂树。余弦相似度的计算方式为:

  4.                               

  5. 其中 𝜔𝑖 标识样本𝑖的权重,这个权重是CatBoost算法中为每个样本随机赋予的,起到样本抽样的效果,目的是为了减少过拟合; 𝑎𝑖 标识样本在树上的输出值,这里也就是Δ(𝑖)值; 𝑔𝑖 即第一步中计算的样本的梯度值。

3 优缺点

优点

  • 性能卓越: 在性能方面可以匹敌任何先进的机器学习算法;
  • 鲁棒性/强健性: 无需调参即可获得较高的模型质量,采用默认参数就可以获得非常好的结果,减少在调参上面花的时间,减少了对很多超参数调优的需求
  • 易于使用: 提供与scikit集成的Python接口,以及R和命令行界面;
  • 实用: 可以处理类别型、数值型特征,支持类别型变量,无需对非数值型特征进行预处理
  • 可扩展: 支持自定义损失函数;
  • 快速、可扩展的GPU版本,可以用基于GPU的梯度提升算法实现来训练你的模型,支持多卡并行提高准确性,
  • 快速预测:即便应对延时非常苛刻的任务也能够快速高效部署模型;

缺点:

  • 对于类别型特征的处理需要大量的内存和时间;
  • 不同随机数的设定对于模型预测结果有一定的影响;

4 代码

import re
import os
import pandas as pd
import numpy as np
import warnings
warnings.filterwarnings('ignore')
import sklearn
from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_curve,roc_auc_score
import matplotlib.pyplot as plt
import gc
from bayes_opt import BayesianOptimization
from catboost import Pool, cv

n_fold = 5
folds = KFold(n_splits=n_fold, shuffle=True, random_state=1314)
oof = np.zeros(len(train_df))
prediction = np.zeros(len(test_df))
for fold_n, (train_index, valid_index) in enumerate(folds.split(train_df)):
    X_train, X_valid = train_df[features].iloc[train_index], train_df[features].iloc[valid_index]
    y_train, y_valid = train_df[label].iloc[train_index], train_df[label].iloc[valid_index]
    cate_features=[]
#     +['corss_卧室_床的数量', 'corss_床的类型_床的数量',
#              'corss_房产类型_卧室数量', 'corss_房产类型_洗手间数量']
    train_pool = Pool(X_train, y_train, cat_features=cate_features)
    eval_pool = Pool(X_valid, y_valid, cat_features=cate_features)
    cbt_model = catboost.CatBoostRegressor(iterations=600, # 注:baseline 提到的分数是用 iterations=60000 得到的,但运行时间有点久
                           learning_rate=0.1, # 注:事实上好几个 property 在 lr=0.1 时收敛巨慢。后面可以考虑调大
                           eval_metric='mse',
                         # n_estimators=3000,
                           # reg_lambda=5,
                           use_best_model=True,
                           random_seed=42,
                           logging_level='Verbose',
                           #task_type='GPU',
                           devices='0',
                           gpu_ram_part=0.5)
    
    cbt_model.fit(train_pool,
              eval_set=eval_pool,
              verbose=1000)
    
    y_pred_valid = cbt_model.predict(X_valid)
    print("valid RMSE")
    print(print(np.sqrt(np.mean(np.square(y_pred_valid - train_df.loc[valid_index,label])))))
    y_pred = cbt_model.predict(test_df[features])
    oof[valid_index] = y_pred_valid.reshape(-1, )
    prediction += y_pred
prediction /= n_fold

from sklearn.metrics import mean_squared_error
score = mean_squared_error(oof, train_df[label].values)
print(score)


#test['价格'] = prediction
#test[['数据ID', '价格']].to_csv('./{}_sub_cat.csv'.format(np.sqrt(score)), index=None)

ref:

 安全验证 - 知乎

安全验证 - 知乎

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/557309.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

OpenHarmony鸿蒙南向开发案例:【智能门铃】

样例简介 智能门铃通过监控来访者信息,告诉主人门外是否有人按铃、有陌生人靠近或者无人状态。主人可以在数字管家中远程接收消息,并根据需要进行远程取消报警和一键开锁。同时,也可以通过室内屏幕获取门外状态。室内屏幕显示界面使用DevEco…

【创建型模式】单例模式

一、单例模式概述 单例模式的定义:又叫单件模式,确保一个类只有一个实例,并提供一个全局访问点。(对象创建型) 要点: 1.某个类只能有一个实例;2.必须自行创建这个实例;3.必须自行向整…

C语言 | 动态内存管理

目录: 1. 为什么要有动态内存分配 2. malloc和free 3. calloc和realloc 4. 常见的动态内存的错误 5. 动态内存经典笔试题分析 6. 柔性数组 1. 为什么要有动态内存分配 我们已经掌握的内存开辟方式有: int val 20; //在栈空间上开辟四个字节 cha…

MR-JE-70A 三菱MR-JE伺服驱动器(750W通用型)

三菱MR-JE伺服驱动器(750W通用型) MR-JE-70A外部连接,MR-JE-70A用户手册,MR-JE-70A 三相或单相AC220V三菱通用型伺服放大器750W,配套电机HG-SN52J-S100、HG-KN73J-S100。 MR-JE-70A参数说明:伺服驱动器通用型750W,三相或单相AC200V~240V 三…

C语言野指针【入门详解】

目录 一、什么是野指针 二、野指针的成因 2.1 指针未初始化 2.2 指针越界访问 2.3 指针指向的空间释放 三、如何规避野指针 3.1 初始化指针 3.2 小心越界访问 3.3 当指针不用时,及时置为空 3.4 避免返回局部变量的地址 *结语: 希望这篇关于指…

IM即时通讯软件,WorkPlus私有化部署全面支持信创环境

在数字化转型的浪潮中,政企单位对即时通讯(IM)软件的需求日益增长。然而,随着信息化程度的提高,数据安全和信息泄露风险也日益凸显。在这样的背景下,WorkPlus作为一款私有化部署的IM即时通讯软件,以其在安全性、管理便…

汇编语言——将BX中的无符号数和有符号数以二进制、八进制、十六进制、十进制形式输出

文章目录 将BX中的无符号数以二进制形式输出将BX中的无符号数以八进制形式输出将BX中的无符号数以十六进制形式输出将BX中的无符号数以十进制形式输出将BX中的有符号数以十进制形式输出 将BX中的无符号数以二进制形式输出 利用移位指令会影响CF,默认dl30h(数字0)&a…

时序深入之CPR(Clock Pessimism Removal)详解

目录 一、CPR概念 二、CPR的计算 三、CPR的开启关闭 四、CPR为0 ​五、参考资料 一、CPR概念 在时序报告的目标时钟路径中,会有一行数据clock pesssimism,第一次见可能都会对这个概念感到疑惑 同样在每条时序路径的summary中,clock pat…

自动化测试Selenium(4)

WebDriver相关api 定位一组元素 webdriver可以很方便地使用findElement方法来定位某个特定的对象, 不过有时候我们需要定位一组对象, 这时候就要使用findElements方法. 定位一组对象一般用于一下场景: 批量操作对象, 比如将页面上的checkbox都勾上. 先获取一组对象, 再在这组…

[最新]访问/加速StackOverFlow的方法

但是有很多问题都是在StackOverFlow上有现成的解决方案,而某度搜索引擎…前一页的回答互相抄袭,看着实在胀眼睛。 话不多说,解决办法: 直接访问插件商店下载插件(最快捷方便,点点就行)&#x…

Linux系统编程——权限概念和权限管理

目录 一,关于Shell 1.1 外壳程序 1.2 shell的作用 1.3 shell运行原理 二,权限概念 2.1 用户与权限 2.2 su(用户切换指令) ​编辑 2.3 提升指令权限和信任名单 三,文件权限 3.1 关于文件权限 3.2 文件访问者…

UG10.如何设置鼠标滚轮操作模型放大缩小方向?

UG10.如何设置鼠标滚轮操作模型放大缩小方向呢?看一下具体操作步骤吧。 首先打开UG10.软件,在主菜单栏选择【文件】下拉菜单,选择【实用工具】。 点击【用户默认设置】。 文章源自四五设计网-https://www.45te.com/45545.html 选中【基本环…

python语言零基础入门——变量与简单数据类型

目录 一、变量 1.创建变量 2.变量的修改 3.变量的命名 (1)常量 (2)标识符 (3)关键字 (4)命名规则 二、简单数据类型 1.变量的数据类型 2.数据类型 3.整型(In…

中断的设备树修改及上机实验(按键驱动)流程

写在前面的话:对于 GPIO 按键,我们并不需要去写驱动程序,使用内核自带的驱动程序 drivers/input/keyboard/gpio_keys.c 就可以,然后你需要做的只是修改设备树指定引脚及键值。 根据驱动文件中的platform_driver中的.of_match_tabl…

C++之类和对象三

目录 拷贝构造函数 定义铺垫 浅拷贝 深拷贝 总结 拷贝构造函数 那在创建对象时,可否创建一个与一个对象一某一样的新对象呢? 定义铺垫 构造函数:只有单个形参,该形参是对本类类型对象的引用(一般常用const修饰)&#xff0c…

2024年华中杯B题论文发布+数据预处理问题一代码免费分享

【腾讯文档】2024年华中杯B题资料汇总 https://docs.qq.com/doc/DSExMdnNsamxCVUJt 行车轨迹估计交通信号灯周期问题 摘要 在城市化迅速发展的今天,交通管理和优化已成为关键的城市运营问题之一。本文将基于题目给出的数据,对行车轨迹估计交通信号灯…

【1577】java网吧收费管理系统Myeclipse开发mysql数据库web结构java编程计算机网页项目

一、源码特点 java 网吧收费管理系统是一套完善的java web信息管理系统,对理解JSP java编程开发语言有帮助,系统具有完整的源代码和数据库,系统主要采用B/S模式开发。开发环境为 TOMCAT7.0,Myeclipse8.5开发,数据库为Mysql5.0…

一篇文章搞定Jenkins自动化部署JDK17+SpringBoot3.X+新版AlibabaCloud打包Docker镜像推送私有镜像仓库

🚀 作者 :“二当家-小D” 🚀 博主简介:⭐前荔枝FM架构师、阿里资深工程师||曾任职于阿里巴巴担任多个项目负责人,8年开发架构经验,精通java,擅长分布式高并发架构,自动化压力测试,微服务容器化k…

Redis中的订阅发布(二)

订阅与发布 订阅频道 每当客户端执行SUBSCRIBE命令订阅某个或某些频道的时候,服务器都会将客户端与被订阅的频道 在pubsub_channels字典中进行关联。 根据频道是否已经有其他订阅者,关联操作分为两种情况执行: 1.如果频道已经有其他订阅者&#xff0c…

微信小程序echart图片不显示 问题解决

目录 1.问题描述:2.解决方法:2.1第一步2.2第二步2.2效果 小结: 1.问题描述: echart图片不显示 图片: 2.解决方法: 2.1第一步 给wxml中的ec-canvas组件添加宽高样式:style"width: 100%…
最新文章