|
- {-# LANGUAGE DeriveDataTypeable, BangPatterns #-}
-
- import qualified Data.ByteString.Lazy as BS
-
- import Data.Int
- import Data.List (findIndex, maximumBy)
- import Data.List.Split (chunksOf)
- import Data.Ord (comparing)
- import Debug.Trace (trace)
-
- import Control.DeepSeq
-
- import Numeric.LinearAlgebra
- import Codec.Compression.GZip
- import System.Random
- import System.Environment (getArgs)
- import System.Console.CmdArgs.Implicit
-
- import Network
-
-- | Cost function selectable from the command line; translated to the
-- library's representation by 'toCostFunction'.
data Cost = Quadratic | CrossEntropy
  deriving (Show, Data)
-
-- | Map the command line cost flag to the network library's cost function.
toCostFunction :: Cost -> CostFunction
toCostFunction c = case c of
  Quadratic    -> QuadraticCost
  CrossEntropy -> CrossEntropyCost
-
-- | All command line options understood by the program.  Field names
-- double as flag names for cmdargs ('arguments' attaches the help texts).
data Arguments = Arguments { eta :: Double, lambda :: Double,
                             filePath :: FilePath, costFunction :: Cost,
                             epochs :: Int, miniBatchSize :: Int,
                             hiddenNeurons :: Int }
  deriving (Show, Data, Typeable)
-
-- | Command line defaults and help texts, consumed by 'cmdArgs' in 'main'.
--
-- Fix: the binding had no top-level type signature, which both hides the
-- intended type and trips @-Wmissing-signatures@; cmdargs annotations are
-- otherwise unchanged.
arguments :: Arguments
arguments = Arguments
  { eta           = 0.5       &= help "Learning rate"
  , lambda        = 5         &= help "Lambda of regularization"
  , filePath      = ""        &= help "Load network from file"
  , costFunction  = Quadratic &= help "Cost function"
  , epochs        = 30        &= help "Number of training epochs"
  , miniBatchSize = 10        &= help "Mini batch size"
  , hiddenNeurons = 30        &= help "Number of neurons in hidden layer"
  } &= summary "MNIST Image classifier in Haskell v1"
-
-- | Read @n@ labelled images from a pair of gzipped MNIST data files and
-- pair each decoded label with its pixel vector.
readImages :: FilePath -> FilePath -> Int64 -> IO ([(Int, Vector R)])
readImages !imgPath !lblPath !n = do
  rawImages <- decompress <$> BS.readFile imgPath
  rawLabels <- decompress <$> BS.readFile lblPath
  -- Bangs force the spines' heads eagerly, mirroring the strict zip below.
  let !pixelRows = map (readImage rawImages) [0 .. n - 1]
      !labels    = map (readLabel rawLabels) [0 .. n - 1]
  return $! zip labels pixelRows
-
-- | Decode image number @n@ into a vector of 784 grey values scaled into
-- [0, 1) by dividing each byte by 256.
readImage :: BS.ByteString -> Int64 -> Vector R
readImage !bytes !n = vector $! map (normalise . BS.index bytes . offset) [0 .. 783]
  where
    -- NOTE(review): 16 looks like the image-file header size and 28^2 the
    -- pixel count per image — confirm against the IDX file format.
    offset i  = n * 28 ^ 2 + 16 + i
    normalise = (/ 256) . fromIntegral
-
-- | Decode label number @n@ as a plain digit.
readLabel :: BS.ByteString -> Int64 -> Int
readLabel bytes n = fromIntegral $! byte
  where
    -- NOTE(review): the 8-byte skip looks like the label-file header —
    -- confirm against the IDX file format.
    byte = BS.index bytes (n + 8)
-
-- | One-hot encode a digit: position @n@ carries 1, the other nine 0.
toLabel :: Int -> Vector R
toLabel n = fromList (map indicator [0 .. 9])
  where indicator i = if i == n then 1 else 0
-
-- | Inverse of 'toLabel': the index of the largest component, i.e. the
-- digit the vector encodes (or the network's strongest guess).
fromLabel :: Vector R -> Int
fromLabel vec = snd strongest
  where strongest = maximumBy (comparing fst) (zip (toList vec) [0 ..])
-
-- | Render a 784-pixel image as 28 newline-terminated rows of ASCII art:
-- 'o' for bright pixels (> 0.5), '.' otherwise.
drawImage :: Vector R -> String
drawImage vec = unlines (map row (chunksOf 28 (toList vec)))
  where
    row       = map pixel
    pixel v   = if v > 0.5 then 'o' else '.'
-
-- | ASCII rendering of a sample's image followed by its decoded label.
drawSample :: Sample Double -> String
drawSample (img, lbl) =
  drawImage img ++ "Label: " ++ show (fromLabel lbl)
-
-- | The 50k MNIST training images paired with their labels.
trainIms :: IO [(Int, Vector R)]
trainIms = readImages imagesFile labelsFile 50000
  where
    imagesFile = "mnist-data/train-images-idx3-ubyte.gz"
    labelsFile = "mnist-data/train-labels-idx1-ubyte.gz"
-
-- | Training set in the (input --> one-hot target) form the trainer expects.
trainSamples :: IO (Samples Double)
trainSamples = do
  pairs <- trainIms
  return $! map (\(lbl, img) -> img --> toLabel lbl) pairs
-
-- | The 10k MNIST test images paired with their labels.
testIms :: IO [(Int, Vector R)]
testIms = readImages imagesFile labelsFile 10000
  where
    imagesFile = "mnist-data/t10k-images-idx3-ubyte.gz"
    labelsFile = "mnist-data/t10k-labels-idx1-ubyte.gz"
-
-- | Test set in the (input --> one-hot target) form 'test' expects.
testSamples :: IO (Samples Double)
testSamples = do
  pairs <- testIms
  return $! map (\(lbl, img) -> img --> toLabel lbl) pairs
-
-- | Print a sample's image together with the digit the network recognises
-- in it (the sample's own label is ignored).
classify :: Network Double -> Sample Double -> IO ()
classify net (img, _) =
  putStrLn $ drawImage img
          ++ "Recognized as "
          ++ show (fromLabel (output net activation img))
-
-- | Count how many samples the network classifies correctly: a hit is a
-- sample whose strongest output neuron matches the one-hot label's index.
--
-- Fixes: evaluate with 'activation' instead of hard-coded @sigmoid@ so the
-- whole file agrees on one activation function (behaviour is unchanged —
-- 'activation' is defined as @sigmoid@ — but a future swap of 'activation'
-- would otherwise silently desynchronise 'test' from 'classify'); and
-- replace the manual sum of 0/1 flags with the idiomatic length-of-filter.
test :: Network Double -> Samples Double -> Int
test net samples = length (filter correct samples)
  where
    -- True when prediction and label decode to the same digit.
    correct (img, lbl) = fromLabel (output net activation img) == fromLabel lbl
-
-- | Activation function used throughout when evaluating the network;
-- currently the sigmoid from the Network library.
activation :: (Floating a) => ActivationFunction a
activation = sigmoid

-- | Derivative paired with 'activation'; must stay in sync with it.
activation' :: (Floating a) => ActivationFunctionDerivative a
activation' = sigmoid'
-
-- | Entry point: parse the command line, load (or freshly create) a
-- network, report its baseline accuracy, train it on the MNIST training
-- set, report final accuracy and save the result to @mnist.net@.
--
-- Fixes: add the missing @main :: IO ()@ signature, and rename the
-- progress callback's @epochs@ parameter, which shadowed the
-- @Arguments.epochs@ record selector used three lines below.
main :: IO ()
main = do
  args <- cmdArgs arguments
  -- An empty file path means "start from a freshly initialised network";
  -- 784 inputs (28x28 pixels), one hidden layer, 10 output digits.
  net <- case filePath args of
           "" -> newNetwork [784, hiddenNeurons args, 10]
           fp -> loadNetwork fp :: IO (Network Double)
  trSamples  <- trainSamples
  tstSamples <- testSamples
  -- Fully force the test set before the baseline measurement.
  let bad = tstSamples `deepseq` test net tstSamples
  putStrLn $ "Initial performance of network: recognized " ++ show bad ++ " of 10k"
  -- Per-epoch progress report handed to the trainer.
  let debug network epochsLeft =
        "Left epochs: " ++ show epochsLeft
        ++ " recognized: " ++ show (test network tstSamples) ++ " of 10k"
  smartNet <- trSamples `deepseq` trainShuffled
                (epochs args) debug net
                (toCostFunction $ costFunction args)
                (lambda args) trSamples
                (miniBatchSize args)
                (eta args)
  let res = test smartNet tstSamples
  putStrLn $ "finished testing. recognized: " ++ show res ++ " of 10k"
  putStrLn "saving network"
  saveNetwork "mnist.net" smartNet
|