{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE DeriveDataTypeable #-}

-- MNIST handwritten-digit classifier driver.
--
-- Loads the gzipped MNIST IDX files from @mnist-data/@, builds (or loads) a
-- network from the project-local "Network" module, trains it with the
-- command-line hyper-parameters, reports test-set hits and writes the trained
-- network to @mnist.net@.

import qualified Data.ByteString.Lazy as BS
import Data.Int
import Data.List (findIndex, maximumBy)
import Data.List.Split (chunksOf)
import Data.Ord (comparing)
import Debug.Trace (trace)
import Control.DeepSeq
import Numeric.LinearAlgebra
import Codec.Compression.GZip
import System.Random
import System.Environment (getArgs)
import System.Console.CmdArgs.Implicit
import Network

-- | Cost function selectable from the command line.
data Cost = Quadratic | CrossEntropy
  deriving (Show, Data)

-- | Map the command-line choice onto the "Network" module's cost function.
toCostFunction :: Cost -> CostFunction
toCostFunction Quadratic    = QuadraticCost
toCostFunction CrossEntropy = CrossEntropyCost

-- | Command-line options; parsed by cmdargs from 'arguments'.
data Arguments = Arguments
  { eta           :: Double   -- ^ learning rate
  , lambda        :: Double   -- ^ regularization parameter passed to training
  , filePath      :: FilePath -- ^ network file to load; empty builds a new one
  , costFunction  :: Cost     -- ^ cost function used while training
  , epochs        :: Int      -- ^ number of training epochs
  , miniBatchSize :: Int      -- ^ mini-batch size
  , hiddenNeurons :: Int      -- ^ size of the single hidden layer
  } deriving (Show, Data, Typeable)

-- | Defaults and help texts for the command-line options.
arguments :: Arguments
arguments = Arguments
  { eta           = 0.5 &= help "Learning rate"
  , lambda        = 5   &= help "Lambda of regularization"
  , filePath      = ""  &= help "Load network from file"
  , costFunction  = Quadratic &= help "Cost function"
  , epochs        = 30  &= help "Number of training epochs"
  , miniBatchSize = 10  &= help "Mini batch size"
  , hiddenNeurons = 30  &= help "Number of neurons in hidden layer"
  } &= summary "MNIST Image classifier in Haskell v1"

-- | Read the first @n@ (label, image) pairs from a gzipped image file and its
-- matching gzipped label file, in file order.
readImages :: FilePath -> FilePath -> Int64 -> IO [(Int, Vector R)]
readImages !imgPath !lblPath !n = do
  imgBytes <- fmap decompress (BS.readFile imgPath)
  lblBytes <- fmap decompress (BS.readFile lblPath)
  let !imgs = map (readImage imgBytes) [0 .. n - 1]
      !lbs  = map (readLabel lblBytes) [0 .. n - 1]
  return $! zip lbs imgs

-- | Decode image @n@ (0-based) from decompressed image-file bytes: skip the
-- 16-byte header, take the 28*28 = 784 pixel bytes and scale each by 1/256.
--
-- The image is sliced out once with 'BS.drop'/'BS.take' instead of calling
-- 'BS.index' per pixel: on a lazy ByteString 'BS.index' walks the chunk list
-- from the beginning on every call, which made loading cost O(pixels * chunks).
-- Calls 'error' if @n@ is negative or the image lies past the end of the data.
readImage :: BS.ByteString -> Int64 -> Vector R
readImage !bytes !n
  | n < 0 || BS.length pixels /= imageSize =
      error ("readImage: image " ++ show n ++ " is out of range")
  | otherwise = vector (map ((/ 256) . fromIntegral) (BS.unpack pixels))
  where
    imageSize = 28 * 28 :: Int64
    pixels    = BS.take imageSize (BS.drop (16 + n * imageSize) bytes)

-- | Decode label @n@ (0-based) from decompressed label-file bytes, skipping
-- the 8-byte header.
readLabel :: BS.ByteString -> Int64 -> Int
readLabel bytes n = fromIntegral $! BS.index bytes (n + 8)

-- | One-hot encode a digit as a length-10 vector.
toLabel :: Int -> Vector R
toLabel n = fromList [if i == n then 1 else 0 | i <- [0 .. 9]]

-- | Index of the largest component (inverse of 'toLabel' for one-hot input).
-- On ties 'maximumBy' keeps the last maximal element, i.e. the highest index.
fromLabel :: Vector R -> Int
fromLabel vec = snd $ maximumBy (comparing fst) (zip (toList vec) [0 ..])

-- | ASCII rendering of a 784-pixel image, 28 pixels per row:
-- @o@ for values above 0.5, @.@ otherwise.
drawImage :: Vector R -> String
drawImage vec = concatMap toLine (chunksOf 28 (toList vec))
  where
    toLine ps = map toChar ps ++ "\n"
    toChar v
      | v > 0.5   = 'o'
      | otherwise = '.'

-- | Render a sample's image followed by its decoded label.
drawSample :: Sample Double -> String
drawSample sample =
  drawImage (fst sample) ++ "Label: " ++ show (fromLabel $ snd sample)

-- | First 50 000 (label, image) pairs of the MNIST training set.
trainIms :: IO [(Int, Vector R)]
trainIms =
  readImages
    "mnist-data/train-images-idx3-ubyte.gz"
    "mnist-data/train-labels-idx1-ubyte.gz"
    50000

-- | Training set as (image --> one-hot label) samples.
trainSamples :: IO (Samples Double)
trainSamples = do
  ims <- trainIms
  return $! [img --> toLabel lbl | (lbl, img) <- ims]

-- | First 10 000 (label, image) pairs of the MNIST test set.
testIms :: IO [(Int, Vector R)]
testIms =
  readImages
    "mnist-data/t10k-images-idx3-ubyte.gz"
    "mnist-data/t10k-labels-idx1-ubyte.gz"
    10000

-- | Test set as (image --> one-hot label) samples.
testSamples :: IO (Samples Double)
testSamples = do
  ims <- testIms
  return $! [img --> toLabel lbl | (lbl, img) <- ims]

-- | Print a sample's image together with the digit the network predicts.
classify :: Network Double -> Sample Double -> IO ()
classify net spl =
  putStrLn
    ( drawImage (fst spl)
        ++ "Recognized as "
        ++ show (fromLabel $ output net activation (fst spl))
    )

-- | Number of samples whose predicted digit equals the expected digit.
-- Uses the shared 'activation' so evaluation stays consistent with 'classify'.
test :: Network Double -> Samples Double -> Int
test net = length . filter (uncurry (==)) . map predict
  where
    predict (img, lbl) = (fromLabel $ output net activation img, fromLabel lbl)

-- | Activation function used for evaluation (currently the sigmoid).
activation :: (Floating a) => ActivationFunction a
activation = sigmoid

-- | Derivative of 'activation'.
activation' :: (Floating a) => ActivationFunctionDerivative a
activation' = sigmoid'

-- | Parse options, load or create a network, report its initial test hits,
-- train it, report the final test hits and save it to @mnist.net@.
main = do
  args <- cmdArgs arguments
  net <- case filePath args of
    "" -> newNetwork [784, hiddenNeurons args, 10]
    fp -> loadNetwork fp :: IO (Network Double)
  trSamples  <- trainSamples
  tstSamples <- testSamples
  -- deepseq forces the whole test set to be loaded before evaluation.
  let initialHits = tstSamples `deepseq` test net tstSamples
  putStrLn $
    "Initial performance of network: recognized "
      ++ show initialHits
      ++ " of 10k"
  -- Progress message handed to the trainer; 'left' is the remaining-epoch
  -- count it supplies (renamed so it no longer shadows the 'epochs' field).
  let debug network left =
        "Left epochs: "
          ++ show left
          ++ " recognized: "
          ++ show (test network tstSamples)
          ++ " of 10k"
  smartNet <-
    trSamples `deepseq`
      trainShuffled
        (epochs args)
        debug
        net
        (toCostFunction $ costFunction args)
        (lambda args)
        trSamples
        (miniBatchSize args)
        (eta args)
  let res = test smartNet tstSamples
  putStrLn $ "finished testing. recognized: " ++ show res ++ " of 10k"
  putStrLn "saving network"
  saveNetwork "mnist.net" smartNet