pytorch:多项式回归

r囧r小猫 2022-01-31 01:53 394阅读 0赞
  1. import numpy as np
  2. import torch
  3. from torch.autograd import Variable
  4. from torch import nn, optim
  5. import matplotlib.pyplot as plt
  6. # 设置字体为中文
  7. plt.rcParams['font.sans-serif'] = ['SimHei']
  8. plt.rcParams['axes.unicode_minus'] = False
  9. # 构造成次方矩阵
  10. def make_fertures(x):
  11. x = x.unsqueeze(1)
  12. return torch.cat([x ** i for i in range(1, 4)], 1)
  13. # y = 0.9+0.5*x+3*x*x+2.4x*x*x
  14. W_target = torch.FloatTensor([0.5, 3, 2.4]).unsqueeze(1)
  15. b_target = torch.FloatTensor([0.9])
  16. # 计算x*w+b
  17. def f(x):
  18. return x.mm(W_target) + b_target.item()
  19. def get_batch(batch_size=32):
  20. random = torch.randn(batch_size)
  21. random = np.sort(random)
  22. random = torch.Tensor(random)
  23. x = make_fertures(random)
  24. y = f(x)
  25. if (torch.cuda.is_available()):
  26. return Variable(x).cuda(), Variable(y).cuda()
  27. else:
  28. return Variable(x), Variable(y)
  29. # 多项式模型
  30. class poly_model(nn.Module):
  31. def __init__(self):
  32. super(poly_model, self).__init__()
  33. self.poly = nn.Linear(3, 1) # 输入时3维,输出是1维
  34. def forward(self, x):
  35. out = self.poly(x)
  36. return out
  37. if torch.cuda.is_available():
  38. model = poly_model().cuda()
  39. else:
  40. model = poly_model()
  41. # 均方误差,随机梯度下降
  42. criterion = nn.MSELoss()
  43. optimizer = optim.SGD(model.parameters(), lr=1e-3)
  44. epoch = 0 # 统计训练次数
  45. ctn = []
  46. lo = []
  47. while True:
  48. batch_x, batch_y = get_batch()
  49. output = model(batch_x)
  50. loss = criterion(output, batch_y)
  51. print_loss = loss.item()
  52. optimizer.zero_grad()
  53. loss.backward()
  54. optimizer.step()
  55. ctn.append(epoch)
  56. lo.append(print_loss)
  57. epoch += 1
  58. if (print_loss < 1e-3):
  59. break
  60. print("Loss: {:.6f} after {} batches".format(loss.item(), epoch))
  61. print(
  62. "==> Learned function: y = {:.2f} + {:.2f}*x + {:.2f}*x^2 + {:.2f}*x^3".format(model.poly.bias[0], model.poly.weight[0][0],
  63. model.poly.weight[0][1],
  64. model.poly.weight[0][2]))
  65. print("==> Actual function: y = {:.2f} + {:.2f}*x + {:.2f}*x^2 + {:.2f}*x^3".format(b_target[0], W_target[0][0],
  66. W_target[1][0], W_target[2][0]))
  67. # 1.可视化真实数据
  68. predict = model(batch_x)
  69. x = batch_x.numpy()[:, 0] # x~1 x~2 x~3
  70. plt.plot(x, batch_y.numpy(), 'ro')
  71. plt.title(label='可视化真实数据')
  72. plt.show()
  73. # 2.可视化拟合函数
  74. predict = predict.data.numpy()
  75. plt.plot(x, predict, 'b')
  76. plt.plot(x, batch_y.numpy(), 'ro')
  77. plt.title(label='可视化拟合函数')
  78. plt.show()
  79. # 3.可视化训练次数和损失
  80. plt.plot(ctn,lo)
  81. plt.xlabel('训练次数')
  82. plt.ylabel('损失值')
  83. plt.title(label='训练次数与损失关系')
  84. plt.show()

实验结果:

watermark_type_ZmFuZ3poZW5naGVpdGk_shadow_10_text_aHR0cHM6Ly9ibG9nLmNzZG4ubmV0L3Nvbmd4aWFvbGluZ2Jhb2Jhbw_size_16_color_FFFFFF_t_70

watermark_type_ZmFuZ3poZW5naGVpdGk_shadow_10_text_aHR0cHM6Ly9ibG9nLmNzZG4ubmV0L3Nvbmd4aWFvbGluZ2Jhb2Jhbw_size_16_color_FFFFFF_t_70 1

watermark_type_ZmFuZ3poZW5naGVpdGk_shadow_10_text_aHR0cHM6Ly9ibG9nLmNzZG4ubmV0L3Nvbmd4aWFvbGluZ2Jhb2Jhbw_size_16_color_FFFFFF_t_70 2

注意:批量产生数据后,进行一个排序,否则可视化时,不是按照x轴从小到大绘制,出现很多折线。对应代码:

  1. random = np.sort(random)
  2. random = torch.Tensor(random)

发表评论

表情:
评论列表 (有 0 条评论,394人围观)

还没有评论,来说两句吧...

相关阅读

    相关 pytorch-线性回归

    线性回归 主要内容包括: 1. 线性回归的基本要素 2. 线性回归模型从零开始的实现 3. 线性回归模型使用pytorch的简洁实现 线性回归的基本要素