pytorch实现线性回归

£神魔★判官ぃ 2023-02-15 12:54 112阅读 0赞

安装依赖

  1. pip install torch numpy matplotlib

编写代码

  1. import torch
  2. import torch.nn as nn
  3. import numpy as np
  4. import matplotlib.pyplot as plt
  5. from torch.autograd import Variable
  6. # Hyper Parameters
  7. input_size = 1
  8. output_size = 1
  9. num_epochs = 10000
  10. learning_rate = 0.001
  11. x_train = np.array([[2.3], [4.4], [3.7], [6.1], [7.3], [2.1], [5.6], [7.7], [8.7], [4.1],
  12. [6.7], [6.1], [7.5], [2.1], [7.2],
  13. [5.6], [5.7], [7.7], [3.1]], dtype=np.float32)
  14. # xtrain生成矩阵数据
  15. y_train = np.array([[3.7], [4.76], [4.], [7.1], [8.6], [3.5], [5.4], [7.6], [7.9], [5.3],
  16. [7.3], [7.5], [8.5], [3.2], [8.7],
  17. [6.4], [6.6], [7.9], [5.3]], dtype=np.float32)
  18. plt.figure()
  19. # 画图散点图
  20. plt.scatter(x_train, y_train)
  21. plt.xlabel('x_train')
  22. # x轴名称
  23. plt.ylabel('y_train')
  24. # y轴名称
  25. # 显示图片
  26. plt.show()
  27. # 线性回归模型
  28. class LinearRegression(nn.Module):
  29. def __init__(self, input_size, output_size):
  30. super(LinearRegression, self).__init__()
  31. self.linear = nn.Linear(input_size, output_size)
  32. def forward(self, x):
  33. out = self.linear(x)
  34. return out
  35. model = LinearRegression(input_size, output_size)
  36. # 损失函数
  37. criterion = nn.MSELoss()
  38. # 优化器
  39. optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
  40. # 训练模型
  41. for epoch in range(num_epochs):
  42. # 张量数据
  43. inputs = Variable(torch.from_numpy(x_train))
  44. targets = Variable(torch.from_numpy(y_train))
  45. # 前向传播
  46. optimizer.zero_grad()
  47. outputs = model(inputs)
  48. loss = criterion(outputs, targets)
  49. # 反向传播
  50. loss.backward()
  51. # 优化器
  52. optimizer.step()
  53. if (epoch + 1) % 5 == 0:
  54. print('Epoch [%d/%d], Loss: %.4f'
  55. % (epoch + 1, num_epochs, loss.data))
  56. # 绘图
  57. model.eval()
  58. predicted = model(Variable(torch.from_numpy(x_train))).data.numpy()
  59. plt.plot(x_train, y_train, 'ro')
  60. plt.plot(x_train, predicted, label='predict')
  61. plt.legend()
  62. plt.show()

运行代码

在这里插入图片描述
在这里插入图片描述

发表评论

表情:
评论列表 (有 0 条评论,112人围观)

还没有评论,来说两句吧...

相关阅读

    相关 pytorch-线性回归

    线性回归 主要内容包括: 1. 线性回归的基本要素 2. 线性回归模型从零开始的实现 3. 线性回归模型使用pytorch的简洁实现 线性回归的基本要素