python cart算法的简单实现

太过爱你忘了你带给我的痛 2021-09-26 06:50 395阅读 0赞

下面是python cart算法的简单实现,可以直接复制下面代码进行运行,即可查看模型的拟合曲线

  1. import matplotlib.pyplot as plt
  2. import numpy as np
  3. from sklearn.tree import DecisionTreeRegressor
  4. def plotfigure(X,X_test,y,yp):
  5. plt.figure()
  6. plt.scatter(X,y,c="k",label="data") #scatter must be 1D (cannot above 2D, for example (200,1))
  7. plt.plot(X_test,yp,c="r",label="max_depth=5",linewidth=2)
  8. plt.xlabel("data")
  9. plt.ylabel("target")
  10. plt.title("Decision Tree Regression")
  11. plt.legend()
  12. plt.show()
  13. x = np.linspace(-5,5,200)
  14. siny = np.sin(x)
  15. X = np.mat(x).T
  16. y = siny+np.random.rand(1,len(siny))*1.5
  17. y= y.tolist()[0]
  18. clf = DecisionTreeRegressor(max_depth=5,min_samples_leaf=10,min_samples_split=10)
  19. clf.fit(X,y)
  20. X_test = np.arange(-5.0,5.0,0.05)[:,np.newaxis]
  21. yp = clf.predict(X_test)
  22. plotfigure(np.array(X)[:,0],X_test,y,yp)
  23. print(X.shape,type(X))

发表评论

表情:
评论列表 (有 0 条评论,395人围观)

还没有评论,来说两句吧...

相关阅读

    相关 十个用Python实现简单算法

    一、算法题目:有1、2、3、4个数字,能组成多少个互不相同且无重复数字的三位数?都是多少? 程序分析:可填在百位、十位、个位的数字都是1、2、3、4。组成所有的排列后再去 掉