PyTorch: Logistic Regression

一时失言乱红尘 2022-01-31 07:47
    import matplotlib.pyplot as plt
    from torch import nn, optim
    import numpy as np
    import torch

    # Font configuration: SimHei supports CJK labels; also render minus signs correctly
    plt.rcParams['font.sans-serif'] = ['SimHei']
    plt.rcParams['axes.unicode_minus'] = False

    # Read the data set: each line of data.txt is "x1,x2,label"
    def readText():
        with open('data.txt', 'r') as f:
            data_list = f.readlines()
        data_list = [i.split('\n')[0] for i in data_list]
        data_list = [i.split(',') for i in data_list]
        data = [(float(i[0]), float(i[1]), int(i[2])) for i in data_list]
        x_data = [[float(i[0]), float(i[1])] for i in data_list]
        y_data = [float(i[2]) for i in data_list]
        return data, x_data, y_data

    # Visualize the raw data
    def visualize(data):
        x0 = list(filter(lambda x: x[-1] == 0.0, data))  # samples of class 0
        x1 = list(filter(lambda x: x[-1] == 1.0, data))  # samples of class 1
        plot_x0_0 = [i[0] for i in x0]  # x coordinates of class 0
        plot_x0_1 = [i[1] for i in x0]  # y coordinates of class 0
        plot_x1_0 = [i[0] for i in x1]  # x coordinates of class 1
        plot_x1_1 = [i[1] for i in x1]  # y coordinates of class 1
        plt.plot(plot_x0_0, plot_x0_1, 'ro', label='Class 0')
        plt.plot(plot_x1_0, plot_x1_1, 'bo', label='Class 1')
        plt.legend()  # show the legend
        plt.title(label='Raw data distribution')
        plt.show()

    # Visualize the data together with the learned decision line
    def visualize_after(data, model):
        x0 = list(filter(lambda x: x[-1] == 0.0, data))  # samples of class 0
        x1 = list(filter(lambda x: x[-1] == 1.0, data))  # samples of class 1
        plot_x0_0 = [i[0] for i in x0]  # x coordinates of class 0
        plot_x0_1 = [i[1] for i in x0]  # y coordinates of class 0
        plot_x1_0 = [i[0] for i in x1]  # x coordinates of class 1
        plot_x1_1 = [i[1] for i in x1]  # y coordinates of class 1
        plt.plot(plot_x0_0, plot_x0_1, 'ro', label='Class 0')
        plt.plot(plot_x1_0, plot_x1_1, 'bo', label='Class 1')
        # Draw the decision line w0*x + w1*y + b = 0
        w0, w1 = model.lr.weight[0]
        w0 = w0.item()
        w1 = w1.item()
        b = model.lr.bias.item()
        plot_x = np.arange(30, 100, 0.1)
        plot_y = (-w0 * plot_x - b) / w1
        plt.plot(plot_x, plot_y, 'yo', label='Decision line')
        plt.legend()  # show the legend
        plt.title(label='Learned decision line')
        plt.show()

    # Logistic regression model
    class LogisticRegression(nn.Module):
        def __init__(self):
            super(LogisticRegression, self).__init__()
            self.lr = nn.Linear(2, 1)
            self.sm = nn.Sigmoid()  # sigmoid squashes the linear output into (0, 1)

        def forward(self, x):
            x = self.lr(x)
            x = self.sm(x)
            return x

    # Main entry point
    if __name__ == '__main__':
        data, x_data, y_data = readText()  # read the data
        # Convert to float32 tensors; nn.BCELoss expects input and target of the
        # same shape, so the targets get an extra dimension to match the (N, 1)
        # model output
        x_data = torch.from_numpy(np.array(x_data, dtype=np.float32))
        y_data = torch.from_numpy(np.array(y_data, dtype=np.float32)).unsqueeze(1)
        visualize(data)  # inspect the raw data distribution
        # Build the model
        model = LogisticRegression()
        if torch.cuda.is_available():
            model.cuda()
        # Loss function and optimizer
        criterion = nn.BCELoss()  # binary cross-entropy loss
        optimizer = optim.SGD(model.parameters(), lr=1e-3, momentum=0.9)
        ctn = []
        lo = []
        for epoch in range(50000):
            if torch.cuda.is_available():
                x = x_data.cuda()
                y = y_data.cuda()
            else:
                x = x_data
                y = y_data
            out = model(x)
            loss = criterion(out, y)
            print_loss = loss.item()
            mask = out.ge(0.5).float()   # predictions >= 0.5 are labeled class 1
            correct = (mask == y).sum()  # count correct predictions
            acc = correct.item() / x.size(0)
            optimizer.zero_grad()  # clear gradients
            loss.backward()        # backpropagate
            optimizer.step()       # update parameters
            ctn.append(epoch + 1)
            lo.append(print_loss)
            if (epoch + 1) % 1000 == 0:
                print('*' * 10)
                print('epoch {}'.format(epoch + 1))
                print('loss is {:.4f}'.format(print_loss))
                print('acc is {:.4f}'.format(acc))
        visualize_after(data, model)
        # Plot the loss against the number of training iterations
        plt.plot(ctn, lo)
        plt.title(label='Loss vs. training iterations')
        plt.xlabel('Iteration')
        plt.ylabel('Loss')
        plt.show()
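
A note on the decision line drawn in visualize_after: the model predicts class 1 when sigmoid(w0*x + w1*y + b) >= 0.5, and sigmoid(z) = 0.5 exactly at z = 0, so the boundary is the set of points where

    w0*x + w1*y + b = 0,  i.e.  y = (-w0*x - b) / w1

which is exactly the expression computed for plot_y. The BCELoss used for training is the standard binary cross-entropy -[y*log(p) + (1-y)*log(1-p)] averaged over the batch, i.e. the maximum-likelihood loss for logistic regression.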

Experiment data:

    34.62365962451697,78.0246928153624,0
    30.2867107622687,43.89499752400101,0
    35.84740876993872,72.90219802708364,0
    60.18259938620976,86.3855209546826,1
    79.0327360507101,75.3443764369103,1
    45.08327747668339,56.3163717815305,0
    61.10666453684766,96.51142588489624,1
    75.02474556738889,46.55401354116538,1
    76.09878670226257,87.42056971926803,1
    84.43281996120035,43.53339331072109,1
    95.86155507093572,38.22527805795094,0
    75.01365838958247,30.60326323428011,0
    82.30705337399482,76.48196330235604,1
    69.36458875970939,97.71869196188608,1
    39.53833914367223,76.03681085115882,0
    53.9710521485623,89.20735013750265,1
    69.07014406283025,52.74046973016765,1
    67.9468554771161746,67.857410673128,0
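
readText() expects these rows in a file named data.txt in the working directory. For convenience, a minimal sketch that writes sample rows to that file (only the first four rows are shown here; in practice the full list above should be written):

    # Save the sample rows to data.txt so readText() can load them
    rows = [
        "34.62365962451697,78.0246928153624,0",
        "30.2867107622687,43.89499752400101,0",
        "35.84740876993872,72.90219802708364,0",
        "60.18259938620976,86.3855209546826,1",
    ]
    with open('data.txt', 'w') as f:
        f.write('\n'.join(rows) + '\n')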

Experiment results:

[Result figures omitted: three watermarked screenshots of the plots from the original post.]

Notes:

(1) The data set is very small; a complete version of the data could not be found.

(2) Mind the dtype and shape conversion: torch.from_numpy keeps NumPy's default float64, while the nn.Linear layer works in float32, and nn.BCELoss needs targets shaped like the (N, 1) model output. The corresponding code is:

    x_data = torch.from_numpy(np.array(x_data, dtype=np.float32))
    y_data = torch.from_numpy(np.array(y_data, dtype=np.float32)).unsqueeze(1)
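
On the shape requirement: nn.BCELoss computes an elementwise loss, so its input and target must have the same shape, which is why the targets above get unsqueeze(1). A minimal, self-contained sketch with toy values (not the real data) illustrating this:

    import torch
    from torch import nn

    criterion = nn.BCELoss()

    out = torch.tensor([[0.9], [0.2], [0.7]])  # stand-in model output, shape (3, 1)
    y = torch.tensor([1.0, 0.0, 1.0])          # raw labels, shape (3,)

    # Expand the labels to (3, 1) so input and target shapes match;
    # recent PyTorch versions refuse mismatched shapes here.
    loss = criterion(out, y.unsqueeze(1))
    print(loss)  # tensor(0.2284)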
