MNIST

Training a small multilayer perceptron on the MNIST handwritten-digit dataset (example adapted from the Flux model zoo).

using Flux, MNIST, CuArrays
using Flux: onehotbatch, argmax, mse, throttle
using Base.Iterators: repeated

# Load the MNIST training set. `x` holds the images (presumably one flattened
# 28×28 image per column, matching the 28^2-wide input layer below — confirm
# against MNIST.jl); `y` starts out as the integer digit labels.
x, y = traindata()

# Re-encode the integer labels as one-hot columns over the digits 0 through 9.
y = onehotbatch(y, 0:9)

# Model: 784 inputs -> 32 sigmoid units -> 10 outputs, normalised by softmax.
m = Chain(Dense(28^2, 32, σ), Dense(32, 10), softmax)

# Move the data and the model's parameters to the GPU. CuArrays is already
# loaded at the top of the file, so it is not re-imported here. To target
# OpenCL instead, load CLArrays and use `cl` in place of `cu` below.
x, y = cu(x), cu(y)
m = mapparams(cu, m)

# Mean-squared error between the model's softmax output and the one-hot targets.
loss(x, y) = mse(m(x), y)

# Present the whole training set as a single batch, 500 times over.
dataset = repeated((x, y), 500)
# Report the loss on the training set; throttled below to at most once every 10 s.
evalcb = () -> @show(loss(x, y))
# Gradient descent over the model's parameters with a learning rate of 1.
opt = SGD(params(m), 1)

Flux.train!(loss, dataset, opt, cb = throttle(evalcb, 10))

# Check the prediction for the first digit. `@show` prints the comparison so the
# result is visible when this file is run as a script (a bare expression would be
# discarded); it still returns the Bool when evaluated in the REPL.
@show argmax(m(x[:, 1]), 0:9) == argmax(y[:, 1], 0:9)