TensorFlow: MNIST for Machine Learning Beginners

男娘i 2022-05-28 02:08

Official TensorFlow getting-started tutorial website

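The MNIST dataset

The MNIST dataset is hosted on Yann LeCun's website. A Python source file is provided here that downloads and installs the dataset automatically. You can download that file and import it into your project with the code below, or simply copy and paste its contents into your own source file.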

    import input_data
    mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)

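The downloaded data consists of 60,000 training examples and 10,000 test examples (mnist.test). Note that read_data_sets holds out the first 5,000 training examples as a validation set (mnist.validation), so mnist.train contains the remaining 55,000.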

Each image is 28 pixels by 28 pixels, so it can be represented as a two-dimensional array of numbers.

We flatten this array into a vector of length 28 x 28 = 784.
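
To make the flattening step concrete, here is a minimal NumPy sketch. The 28 x 28 array is a stand-in filled with the numbers 0 to 783, not real MNIST data, and the variable names are only for illustration.

```python
import numpy as np

# A stand-in 28x28 "image" whose pixel values are 0..783 (not real MNIST data).
image = np.arange(28 * 28, dtype=np.float32).reshape(28, 28)

# Row-major flattening: the rows are laid end to end in one long vector.
vector = image.reshape(-1)

print(vector.shape)  # (784,)

# The pixel at row r, column c ends up at index r * 28 + c.
r, c = 3, 5
print(vector[r * 28 + c] == image[r, c])  # True
```

Flattening discards the 2-D layout of the image, which is acceptable here because softmax regression treats every pixel as an independent input feature.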

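As a result, mnist.train.images is a tensor of shape [55000, 784] (the 60,000 training images minus the 5,000 held out for validation). The first dimension indexes the images and the second indexes the pixels within each image. Each element is the intensity of one pixel in one image, a value between 0 and 1.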

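The corresponding MNIST labels are digits from 0 to 9 that say which digit each image shows. Because the data is loaded with one_hot=True, each label is stored as a 10-dimensional one-hot vector, so mnist.train.labels is a [55000, 10] matrix.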
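
The one-hot encoding can be sketched in a few lines of NumPy. The helper name `one_hot` and the three sample labels are made up for illustration; the `dense_to_one_hot` function in the loader code listed later in this post does the same job.

```python
import numpy as np

def one_hot(labels, num_classes=10):
    """Turn integer class labels into rows with a single 1 (hypothetical helper)."""
    encoded = np.zeros((labels.shape[0], num_classes))
    # Row i gets a 1 in the column given by labels[i].
    encoded[np.arange(labels.shape[0]), labels] = 1
    return encoded

labels = np.array([3, 0, 9])  # made-up digit labels
encoded = one_hot(labels)

print(encoded.shape)           # (3, 10)
print(encoded.argmax(axis=1))  # [3 0 9]
```

Taking the argmax of each row recovers the original digit, which is exactly what the evaluation code relies on.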


Softmax regression

A softmax model assigns probabilities to an object being one of several different classes.

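It works in two steps:

  • First, we compute a weighted sum of the image's pixel intensities for each class. A weight is negative if a high intensity at that pixel is evidence against the image belonging to the class, and positive if it is evidence in favor.
  • Second, the summed evidence for each class is exponentiated and normalized so that the values add up to 1 and form a valid probability distribution.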
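
The two steps can be sketched in NumPy as follows. This is only an illustration under stated assumptions: the inputs are random numbers rather than MNIST images, the weights are small random values, and subtracting the row maximum before exponentiating is a common numerical-stability trick that the text above does not mention.

```python
import numpy as np

def softmax(evidence):
    """Exponentiate and normalize each row so that it sums to 1."""
    # Subtracting the row maximum does not change the result but avoids overflow.
    shifted = evidence - np.max(evidence, axis=-1, keepdims=True)
    exp = np.exp(shifted)
    return exp / np.sum(exp, axis=-1, keepdims=True)

rng = np.random.default_rng(0)
x = rng.random((2, 784))                    # two fake flattened images
W = rng.normal(scale=0.01, size=(784, 10))  # one weight per pixel per class
b = np.zeros(10)                            # one bias per class

evidence = x @ W + b       # step 1: weighted sum of pixel intensities
probs = softmax(evidence)  # step 2: normalize into probabilities

print(probs.shape)  # (2, 10)
```

Each row of `probs` is non-negative and sums to 1, so it can be read as the model's belief over the ten digits for that image.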

Training the model

The cost function used is cross-entropy.
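
As a rough illustration, cross-entropy sums -y' * log(y) over all classes, where y' is the one-hot true label and y is the predicted distribution. The sketch below uses made-up three-class predictions; the small `eps` clip is an added safeguard against log(0) and is not part of the formula itself.

```python
import numpy as np

def cross_entropy(y_true, y_pred, eps=1e-12):
    """Summed cross-entropy between one-hot labels and predicted probabilities."""
    y_pred = np.clip(y_pred, eps, 1.0)  # assumption: guard against log(0)
    return -np.sum(y_true * np.log(y_pred))

y_true = np.array([[0.0, 1.0, 0.0]])        # the true class is index 1
confident = np.array([[0.05, 0.90, 0.05]])  # made-up good prediction
unsure = np.array([[0.40, 0.20, 0.40]])     # made-up poor prediction

# A prediction that puts more mass on the true class has a lower cost.
print(cross_entropy(y_true, confident) < cross_entropy(y_true, unsure))  # True
```

Only the probability assigned to the true class contributes to the sum, so the cost falls toward 0 as that probability approaches 1.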

Here we ask TensorFlow to minimize the cross-entropy using the gradient descent algorithm with a learning rate of 0.01.
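
To show what a single gradient descent update does, here is a hand-written NumPy sketch for softmax regression. It is not how TensorFlow computes gradients internally; it assumes a tiny fake batch (5 examples, 4 features, 3 classes) and uses the standard closed-form gradient of the summed cross-entropy, which is the transposed input times (predictions minus labels).

```python
import numpy as np

def softmax(z):
    z = z - z.max(axis=1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)

def summed_cross_entropy(W, b, x, y_true):
    return -np.sum(y_true * np.log(softmax(x @ W + b)))

rng = np.random.default_rng(0)
x = rng.random((5, 4))               # tiny fake batch with values in [0, 1)
y_true = np.eye(3)[[0, 1, 2, 1, 0]]  # made-up one-hot labels

W = np.zeros((4, 3))
b = np.zeros(3)
learning_rate = 0.01

loss_before = summed_cross_entropy(W, b, x, y_true)

# Gradient of the summed cross-entropy with respect to the logits.
error = softmax(x @ W + b) - y_true
# Move each parameter a small step against its gradient.
W -= learning_rate * (x.T @ error)
b -= learning_rate * error.sum(axis=0)

loss_after = summed_cross_entropy(W, b, x, y_true)
print(loss_after < loss_before)  # True
```

Repeating this update over many batches is what the training loop in the full script does through `train_step`.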

Evaluating the model

Compute the model's accuracy on the test set.
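
Accuracy here means the fraction of images whose most probable predicted class matches the true label. A small NumPy sketch with made-up three-class predictions mirrors the argmax, equality, and mean steps:

```python
import numpy as np

# Made-up predicted probabilities for four examples and three classes.
y_pred = np.array([[0.1, 0.7, 0.2],
                   [0.8, 0.1, 0.1],
                   [0.3, 0.3, 0.4],
                   [0.6, 0.3, 0.1]])
# Made-up one-hot true labels.
y_true = np.array([[0, 1, 0],
                   [1, 0, 0],
                   [0, 1, 0],
                   [1, 0, 0]])

# Compare the most probable class with the true class for each example.
correct = np.argmax(y_pred, axis=1) == np.argmax(y_true, axis=1)
# Cast booleans to floats and average them to get the fraction correct.
accuracy = correct.astype(np.float32).mean()

print(accuracy)  # 0.75
```

Three of the four predictions match, so the accuracy is 0.75.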

Code

    #!/usr/bin/env python
    # coding=utf-8
    # File Name: start.py
    # Author: Fang Pin
    # Mail: fangpin1993@hotmail.com
    # Created Time: Mon 6/27 20:33:46 2016
    import input_data
    import tensorflow as tf

    mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)

    # Softmax regression model
    x = tf.placeholder("float", [None, 784])
    W = tf.Variable(tf.zeros([784, 10]))
    b = tf.Variable(tf.zeros([10]))
    y = tf.nn.softmax(tf.matmul(x, W) + b)

    # Train the model
    y_ = tf.placeholder("float", [None, 10])
    cross_entropy = -tf.reduce_sum(y_ * tf.log(y))
    train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)

    # Deprecated in later releases in favor of tf.global_variables_initializer().
    init = tf.initialize_all_variables()
    sess = tf.Session()
    sess.run(init)

    # Run 1000 gradient descent steps on random batches of 100 examples.
    for i in range(1000):
        batch_xs, batch_ys = mnist.train.next_batch(100)
        sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})

    # Evaluate the model on the test set
    correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))
    print(sess.run(accuracy, feed_dict={
        x: mnist.test.images, y_: mnist.test.labels}))

input_data.py

    # Copyright 2015 Google Inc. All Rights Reserved.
    #
    # Licensed under the Apache License, Version 2.0 (the "License");
    # you may not use this file except in compliance with the License.
    # You may obtain a copy of the License at
    #
    #     http://www.apache.org/licenses/LICENSE-2.0
    #
    # Unless required by applicable law or agreed to in writing, software
    # distributed under the License is distributed on an "AS IS" BASIS,
    # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    # See the License for the specific language governing permissions and
    # limitations under the License.
    # ==============================================================================
    """Functions for downloading and reading MNIST data."""
    from __future__ import absolute_import
    from __future__ import division
    from __future__ import print_function

    import gzip
    import os

    import tensorflow.python.platform
    import numpy
    from six.moves import urllib
    from six.moves import xrange  # pylint: disable=redefined-builtin
    import tensorflow as tf

    SOURCE_URL = 'http://yann.lecun.com/exdb/mnist/'

    def maybe_download(filename, work_directory):
        """Download the data from Yann's website, unless it's already here."""
        if not os.path.exists(work_directory):
            os.mkdir(work_directory)
        filepath = os.path.join(work_directory, filename)
        if not os.path.exists(filepath):
            filepath, _ = urllib.request.urlretrieve(SOURCE_URL + filename, filepath)
            statinfo = os.stat(filepath)
            print('Successfully downloaded', filename, statinfo.st_size, 'bytes.')
        return filepath

    def _read32(bytestream):
        # Read one big-endian unsigned 32-bit integer from the stream.
        dt = numpy.dtype(numpy.uint32).newbyteorder('>')
        return numpy.frombuffer(bytestream.read(4), dtype=dt)[0]

    def extract_images(filename):
        """Extract the images into a 4D uint8 numpy array [index, y, x, depth]."""
        print('Extracting', filename)
        with gzip.open(filename) as bytestream:
            magic = _read32(bytestream)
            if magic != 2051:
                raise ValueError(
                    'Invalid magic number %d in MNIST image file: %s' %
                    (magic, filename))
            num_images = _read32(bytestream)
            rows = _read32(bytestream)
            cols = _read32(bytestream)
            buf = bytestream.read(rows * cols * num_images)
            data = numpy.frombuffer(buf, dtype=numpy.uint8)
            data = data.reshape(num_images, rows, cols, 1)
            return data

    def dense_to_one_hot(labels_dense, num_classes=10):
        """Convert class labels from scalars to one-hot vectors."""
        num_labels = labels_dense.shape[0]
        index_offset = numpy.arange(num_labels) * num_classes
        labels_one_hot = numpy.zeros((num_labels, num_classes))
        labels_one_hot.flat[index_offset + labels_dense.ravel()] = 1
        return labels_one_hot

    def extract_labels(filename, one_hot=False):
        """Extract the labels into a 1D uint8 numpy array [index]."""
        print('Extracting', filename)
        with gzip.open(filename) as bytestream:
            magic = _read32(bytestream)
            if magic != 2049:
                raise ValueError(
                    'Invalid magic number %d in MNIST label file: %s' %
                    (magic, filename))
            num_items = _read32(bytestream)
            buf = bytestream.read(num_items)
            labels = numpy.frombuffer(buf, dtype=numpy.uint8)
            if one_hot:
                return dense_to_one_hot(labels)
            return labels

    class DataSet(object):

        def __init__(self, images, labels, fake_data=False, one_hot=False,
                     dtype=tf.float32):
            """Construct a DataSet.

            one_hot arg is used only if fake_data is true. `dtype` can be either
            `uint8` to leave the input as `[0, 255]`, or `float32` to rescale into
            `[0, 1]`.
            """
            dtype = tf.as_dtype(dtype).base_dtype
            if dtype not in (tf.uint8, tf.float32):
                raise TypeError('Invalid image dtype %r, expected uint8 or float32' %
                                dtype)
            if fake_data:
                self._num_examples = 10000
                self.one_hot = one_hot
            else:
                assert images.shape[0] == labels.shape[0], (
                    'images.shape: %s labels.shape: %s' % (images.shape,
                                                           labels.shape))
                self._num_examples = images.shape[0]
                # Convert shape from [num examples, rows, columns, depth]
                # to [num examples, rows*columns] (assuming depth == 1)
                assert images.shape[3] == 1
                images = images.reshape(images.shape[0],
                                        images.shape[1] * images.shape[2])
                if dtype == tf.float32:
                    # Convert from [0, 255] -> [0.0, 1.0].
                    images = images.astype(numpy.float32)
                    images = numpy.multiply(images, 1.0 / 255.0)
            self._images = images
            self._labels = labels
            self._epochs_completed = 0
            self._index_in_epoch = 0

        @property
        def images(self):
            return self._images

        @property
        def labels(self):
            return self._labels

        @property
        def num_examples(self):
            return self._num_examples

        @property
        def epochs_completed(self):
            return self._epochs_completed

        def next_batch(self, batch_size, fake_data=False):
            """Return the next `batch_size` examples from this data set."""
            if fake_data:
                fake_image = [1] * 784
                if self.one_hot:
                    fake_label = [1] + [0] * 9
                else:
                    fake_label = 0
                return [fake_image for _ in xrange(batch_size)], [
                    fake_label for _ in xrange(batch_size)]
            start = self._index_in_epoch
            self._index_in_epoch += batch_size
            if self._index_in_epoch > self._num_examples:
                # Finished epoch
                self._epochs_completed += 1
                # Shuffle the data
                perm = numpy.arange(self._num_examples)
                numpy.random.shuffle(perm)
                self._images = self._images[perm]
                self._labels = self._labels[perm]
                # Start next epoch
                start = 0
                self._index_in_epoch = batch_size
                assert batch_size <= self._num_examples
            end = self._index_in_epoch
            return self._images[start:end], self._labels[start:end]

    def read_data_sets(train_dir, fake_data=False, one_hot=False, dtype=tf.float32):
        class DataSets(object):
            pass
        data_sets = DataSets()

        if fake_data:
            def fake():
                return DataSet([], [], fake_data=True, one_hot=one_hot, dtype=dtype)
            data_sets.train = fake()
            data_sets.validation = fake()
            data_sets.test = fake()
            return data_sets

        TRAIN_IMAGES = 'train-images-idx3-ubyte.gz'
        TRAIN_LABELS = 'train-labels-idx1-ubyte.gz'
        TEST_IMAGES = 't10k-images-idx3-ubyte.gz'
        TEST_LABELS = 't10k-labels-idx1-ubyte.gz'
        VALIDATION_SIZE = 5000

        local_file = maybe_download(TRAIN_IMAGES, train_dir)
        train_images = extract_images(local_file)
        local_file = maybe_download(TRAIN_LABELS, train_dir)
        train_labels = extract_labels(local_file, one_hot=one_hot)
        local_file = maybe_download(TEST_IMAGES, train_dir)
        test_images = extract_images(local_file)
        local_file = maybe_download(TEST_LABELS, train_dir)
        test_labels = extract_labels(local_file, one_hot=one_hot)

        # Hold out the first VALIDATION_SIZE training examples for validation.
        validation_images = train_images[:VALIDATION_SIZE]
        validation_labels = train_labels[:VALIDATION_SIZE]
        train_images = train_images[VALIDATION_SIZE:]
        train_labels = train_labels[VALIDATION_SIZE:]

        data_sets.train = DataSet(train_images, train_labels, dtype=dtype)
        data_sets.validation = DataSet(validation_images, validation_labels,
                                       dtype=dtype)
        data_sets.test = DataSet(test_images, test_labels, dtype=dtype)
        return data_sets
