博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
xgboost回归损失函数自定义【一】
阅读量:6113 次
发布时间:2019-06-21

本文共 3607 字,大约阅读时间需要 12 分钟。

hot3.png

写在前面:

每当提到损失函数,很多人都有个误解,以为用在GridSearchCV(网格搜索交叉验证“Cross Validation”)里边的scoring就是损失函数,其实并不是。我们使用构造函数构造XGBRegressor的时候,里边的objective参数才是真正的损失函数(loss function)。xgb使用sklearn api的时候需要用到的损失函数,其返回值是一阶导和二阶导,而GridSearchCV使用的scoring函数,返回的是一个float类型的数值评分(或叫准确率、或叫偏差值)。

You should be careful with the notation.

There are 2 levels of optimization here:

  1. The loss function optimized when the XGBRegressor is fitted to the data.
  2. The scoring function that is optimized during the grid search.

I prefer calling the second scoring function instead of loss function, since loss function usually refers to a term that is subject to optimization during the model fitting process itself.

因此,下文对于objective,统一叫“目标函数”;而对scoring,统一叫“评价函数”。

 

 

========== 原文分割线 ===================

许多特定的任务需要定制目标函数,来达到更优的效果。这里以xgboost的回归预测为例,介绍一下objective函数的定制过程。一个简单的例子如下:

def customObj1(real, predict):    grad = predict - real    hess = np.power(np.abs(grad), 0.5)    return grad, hess

网上有许多教程定义的objective函数中的第一个参数是preds,第二个是dtrain,而本文由于使用xgboost的sklearn API,因此定制的objective函数需要与sklearn的格式相符。调用目标函数的过程如下:

model = xgb.XGBRegressor(objective=customObj1,                         booster="gblinear")

下面是不同迭代次数的动画演示:

我们发现,不同的目标函数对模型的收敛速度影响较大,但最终收敛目标大致相同,如下图:

完整代码如下:

# coding=utf-8import pandas as pdimport numpy as npimport xgboost as xgbimport matplotlib.pyplot as pltplt.rcParams.update({'figure.autolayout': True})df = pd.DataFrame({'x': [-2.1, -0.9,  0,  1,  2, 2.5,  3,  4],                   'y': [ -10,    0, -5, 10, 20,  10, 30, 40]})X_train = df.drop('y', axis=1)Y_train = df['y']X_pred = [-4, -3, -2, -1, 0, 0.4, 0.6, 1, 1.4, 1.6, 2, 3, 4, 5, 6, 7, 8]def process_list(list_in):    result = map(lambda x: "%8.2f" % round(float(x), 2), list_in)    return list(result)def customObj3(real, predict):    grad = predict - real    hess = np.power(np.abs(grad), 0.1)    # print 'predict', process_list(predict.tolist()), type(predict)    # print ' real  ', process_list(real.tolist()), type(real)    # print ' grad  ', process_list(grad.tolist()), type(grad)    # print ' hess  ', process_list(hess.tolist()), type(hess), '\n'    return grad, hessdef customObj1(real, predict):    grad = predict - real    hess = np.power(np.abs(grad), 0.5)    return grad, hessfor n_estimators in range(5, 600, 5):    booster_str = "gblinear"    model = xgb.XGBRegressor(objective=customObj1,                             booster=booster_str,                             n_estimators=n_estimators)    model2 = xgb.XGBRegressor(objective="reg:linear",                              booster=booster_str,                              n_estimators=n_estimators)    model3 = xgb.XGBRegressor(objective=customObj3,                              booster=booster_str,                              n_estimators=n_estimators)    model.fit(X=X_train, y=Y_train)    model2.fit(X=X_train, y=Y_train)    model3.fit(X=X_train, y=Y_train)    y_pred = model.predict(data=pd.DataFrame({'x': X_pred}))    y_pred2 = model2.predict(data=pd.DataFrame({'x': X_pred}))    y_pred3 = model3.predict(data=pd.DataFrame({'x': X_pred}))    plt.figure(figsize=(6, 5))    plt.axes().set(title='n_estimators='+str(n_estimators))    plt.plot(df['x'], df['y'], marker='o', linestyle=":", label="Real Y")    plt.plot(X_pred, y_pred, label="predict - real; |grad|**0.5")    plt.plot(X_pred, y_pred3, label="predict - real; |grad|**0.1")    plt.plot(X_pred, y_pred2, label="reg:linear")    plt.xlim(-4.5, 8.5)    plt.ylim(-25, 55)    plt.legend()    # plt.show()    plt.savefig("output/n_estimators_"+str(n_estimators)+".jpg")    plt.close()    print(n_estimators)

 

转载于:https://my.oschina.net/u/2996334/blog/3006786

你可能感兴趣的文章
注解开发
查看>>
如何用 Robotframework 来编写优秀的测试用例
查看>>
Django之FBV与CBV
查看>>
Vue之项目搭建
查看>>
app内部H5测试点总结
查看>>
Docker - 创建支持SSH服务的容器镜像
查看>>
[TC13761]Mutalisk
查看>>
三级菜单
查看>>
Data Wrangling文摘:Non-tidy-data
查看>>
加解密算法、消息摘要、消息认证技术、数字签名与公钥证书
查看>>
while()
查看>>
常用限制input的方法
查看>>
Ext Js简单事件处理和对象作用域
查看>>
IIS7下使用urlrewriter.dll配置
查看>>
12.通过微信小程序端访问企查查(采集工商信息)
查看>>
WinXp 开机登录密码
查看>>
POJ 1001 Exponentiation
查看>>
HDU 4377 Sub Sequence[串构造]
查看>>
云时代架构阅读笔记之四
查看>>
WEB请求处理一:浏览器请求发起处理
查看>>