统计学习(李航)第二章:感知机

曾经终败给现在 2023-02-11 10:13 141阅读 0赞

1.

  1. import numpy as np
  2. import matplotlib.pyplot as plt
  3. from sklearn.datasets import load_iris
  4. import pandas as pd
  5. #数据集加载
  6. iris = load_iris()
  7. df = pd.DataFrame(iris.data, columns=iris.feature_names)
  8. df['label'] = iris.target
  9. df.columns = [
  10. 'sepal length', 'sepal width', 'petal length', 'petal width', 'label'
  11. ]
  12. print(df.label.value_counts())
  13. plt.scatter(df[:50]['sepal length'], df[:50]['sepal width'], label='0')
  14. plt.scatter(df[50:100]['sepal length'], df[50:100]['sepal width'], label='1')
  15. plt.xlabel('sepal length')
  16. plt.ylabel('sepal width')
  17. plt.legend()
  18. plt.show()
  19. data = np.array(df.iloc[:100, [0, 1, -1]])
  20. x, y = data[:,:-1], data[:,-1]
  21. y = np.array([1 if i==1 else -1 for i in y])
  22. '感知机'
  23. #数据线性可分,二分类
  24. class Model:
  25. def __init__(self):
  26. self.w = np.ones(len(data[0])-1, dtype=np.float32)
  27. self.b = 0
  28. self.l_rate = 0.1
  29. def sign(self, x, w, b):
  30. return np.dot(w, x) + b
  31. def fit(self, x_train, y_train):#随机梯度下降
  32. is_wrong = False
  33. while not is_wrong:
  34. wrong_count = 0
  35. for i in range(len(x_train)):
  36. x = x_train[i]
  37. y = y_train[i]
  38. if y * self.sign(x, self.w, self.b) <= 0:#感知机损失函数L(w,b) = -∑y*(wx+b)
  39. #w,b参数更新
  40. self.w = self.w + self.l_rate * np.dot(y, x)
  41. self.b = self.b + self.l_rate * y
  42. wrong_count += 1
  43. if wrong_count == 0:
  44. is_wrong = True
  45. return 'Perceptron Model'
  46. def score(self):
  47. pass
  48. perceptron = Model()
  49. print(perceptron.fit(x, y))
  50. x_points = np.linspace(4, 7, 10)
  51. y_ = -(perceptron.w[0] * x_points + perceptron.b) / perceptron.w[1]#各分类点到超平面的距离
  52. plt.plot(x_points, y_)
  53. plt.plot(data[:50, 0], data[:50, 1], 'bo', color='blue', label='0')
  54. plt.plot(data[50:100, 0], data[50:100, 1], 'bo', color='orange', label='1')
  55. plt.xlabel('sepal length')
  56. plt.ylabel('sepal width')
  57. plt.legend()
  58. plt.show()

[图片(CSDN 水印:https://blog.csdn.net/qq_39938666)]

发表评论

表情:
评论列表 (有 0 条评论,141人围观)

还没有评论,来说两句吧...

相关阅读