TF estimator API in R
We learned about the TensorFlow estimator API in Chapter 2. In R, this API is implemented by the tfestimators R package.
As an example, we walk through the MLP model for classifying handwritten digits from the MNIST dataset, which is available at the following link: https://tensorflow.rstudio.com/tfestimators/articles/examples/mnist.html.
Note
You can follow along with the code in the Jupyter R notebook ch-17b_TFEstimator_in_R.
- First, load the libraries:
library(tensorflow)
library(tfestimators)
- Define the hyper-parameters:
batch_size <- 128
n_classes <- 10
n_steps <- 100
- Prepare the data:
# initialize data directory
data_dir <- "~/datasets/mnist"
dir.create(data_dir, recursive = TRUE, showWarnings = FALSE)

# download the MNIST data sets, and read them into R
sources <- list(
  train = list(
    x = "https://storage.googleapis.com/cvdf-datasets/mnist/train-images-idx3-ubyte.gz",
    y = "https://storage.googleapis.com/cvdf-datasets/mnist/train-labels-idx1-ubyte.gz...
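The MNIST files downloaded in the data-preparation step are gzip-compressed and stored in the IDX binary format. As a minimal sketch of how such a file can be read with base R alone, the hypothetical helper read_idx_file() below (not part of tfestimators, and not necessarily what the notebook does) returns a matrix with one row per image, or a plain vector for a labels file. It assumes the unsigned-byte IDX layout used by the MNIST files: a 4-byte magic number whose third byte is the type code 0x08 and whose fourth byte is the number of dimensions, then one big-endian 32-bit size per dimension, then the data bytes in row-major order.

```r
# Hypothetical helper (assumption, not a tfestimators function):
# read one unsigned-byte IDX file, gzip-compressed or not.
read_idx_file <- function(path) {
  # gzfile() also opens uncompressed files transparently for reading
  con <- gzfile(path, open = "rb")
  on.exit(close(con), add = TRUE)

  # Magic number: two zero bytes, type code 0x08 (unsigned byte), number of dims
  magic <- readBin(con, what = "raw", n = 4)
  if (!identical(as.integer(magic[1:3]), c(0L, 0L, 8L))) {
    stop("not an unsigned-byte IDX file: ", path)
  }
  n_dims <- as.integer(magic[4])

  # Each dimension size is stored as a 32-bit big-endian integer
  dims <- readBin(con, what = "integer", n = n_dims, size = 4, endian = "big")

  # Payload: one unsigned byte per value, in row-major order
  values <- readBin(con, what = "integer", n = prod(dims), size = 1, signed = FALSE)

  if (n_dims == 1) {
    return(values)  # e.g. a vector of digit labels
  }
  # One row per example, e.g. 28 x 28 = 784 pixel columns for MNIST images
  matrix(values, nrow = dims[1], ncol = prod(dims[-1]), byrow = TRUE)
}

# Usage: write a tiny synthetic IDX file (2 "images" of 2 x 3 pixels)
# instead of downloading MNIST
img_path <- tempfile(fileext = ".gz")
con <- gzfile(img_path, open = "wb")
writeBin(as.raw(c(0, 0, 8, 3)), con)                   # magic: ubyte, 3 dims
writeBin(c(2L, 2L, 3L), con, size = 4, endian = "big")  # sizes: 2 x 2 x 3
writeBin(as.raw(0:11), con)                             # pixel values 0..11
close(con)

x <- read_idx_file(img_path)
dim(x)   # 2 6
x[2, ]   # 6 7 8 9 10 11
```

Writing the synthetic file through gzfile() keeps the sketch free of network access; with the real data, the path of a downloaded .gz file would be passed instead.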