{-# LANGUAGE BangPatterns       #-}
{-# LANGUAGE DeriveDataTypeable #-}

-- | MNIST digit classifier: decodes the gzipped IDX image/label files,
-- trains a feed-forward 'Network', and reports recognition accuracy on
-- the 10k test set.
import qualified Data.ByteString.Lazy as BS
import Data.Int
import Data.Binary
import Data.List (findIndex, maximumBy)
import Data.List.Split (chunksOf)
import Data.Ord (comparing)
import Debug.Trace (trace)
import Codec.Compression.GZip
import qualified Data.ByteString.Lazy as B
import Control.DeepSeq
import Numeric.LinearAlgebra
import System.Random
import System.Environment (getArgs)
import System.Console.CmdArgs.Implicit

import Network

-- | Cost function selectable from the command line.
data Cost = Quadratic | CrossEntropy deriving (Show, Data)

-- | Map the command-line cost flag to the 'Network' library's cost function.
toCostFunction :: Cost -> CostFunction
toCostFunction Quadratic    = QuadraticCost
toCostFunction CrossEntropy = CrossEntropyCost

-- | Command-line options, parsed by cmdargs.
data Arguments = Arguments
  { eta           :: Double    -- ^ learning rate
  , lambda        :: Double    -- ^ regularization strength
  , filePath      :: FilePath  -- ^ optional network file to resume from ("" = fresh network)
  , costFunction  :: Cost      -- ^ which cost function to train with
  , epochs        :: Int       -- ^ number of training epochs
  , miniBatchSize :: Int       -- ^ samples per mini batch
  , hiddenNeurons :: Int       -- ^ size of the single hidden layer
  } deriving (Show, Data, Typeable)

-- | Defaults and help texts for the command-line interface.
arguments :: Arguments
arguments = Arguments
  { eta           = 0.5 &= help "Learning rate"
  , lambda        = 5 &= help "Lambda of regularization"
  , filePath      = "" &= help "Load network from file"
  , costFunction  = CrossEntropy &= help "Cost function"
  , epochs        = 30 &= help "Number of training epochs"
  , miniBatchSize = 10 &= help "Mini batch size"
  , hiddenNeurons = 30 &= help "Number of neurons in hidden layer"
  } &= summary "MNIST Image classifier in Haskell v1"

-- | Read the first @n@ (label, image) pairs from a pair of gzipped IDX files.
readImages :: FilePath -> FilePath -> Int64 -> IO [(Int, Vector R)]
readImages !imgPath !lblPath !n = do
  imgBytes <- fmap decompress (BS.readFile imgPath)
  lblBytes <- fmap decompress (BS.readFile lblPath)
  let !imgs = map (readImage imgBytes) [0 .. n - 1]
      !lbs  = map (readLabel lblBytes) [0 .. n - 1]
  return $! zip lbs imgs

-- | Decode image @n@ from a decompressed IDX image file: pixel data
-- starts at byte offset 16, each image is 28*28 = 784 bytes, and
-- intensities are scaled from [0, 255] into [0, 1).
readImage :: BS.ByteString -> Int64 -> Vector R
readImage !bytes !n =
  vector $!
    map ((/ 256) . fromIntegral . BS.index bytes . (n * 28 ^ 2 + 16 +)) [0 .. 783]

-- | Decode label @n@ from a decompressed IDX label file: label bytes
-- start at offset 8.
readLabel :: BS.ByteString -> Int64 -> Int
readLabel bytes n = fromIntegral $! BS.index bytes (n + 8)

-- | One-hot encode a digit 0-9 as a length-10 target vector.
toLabel :: Int -> Vector R
toLabel n = fromList [if i == n then 1 else 0 | i <- [0 .. 9]]

-- | Inverse of 'toLabel': the index of the strongest activation.
fromLabel :: Vector R -> Int
fromLabel vec = snd $ maximumBy (comparing fst) (zip (toList vec) [0 ..])

-- | Render a 784-pixel image as 28 rows of ASCII art.
drawImage :: Vector R -> String
drawImage vec = concatMap toLine (chunksOf 28 (toList vec))
  where
    toLine ps = map toChar ps ++ "\n"
    toChar v
      | v > 0.5   = 'o'
      | otherwise = '.'

-- | Render a sample's image together with its expected label.
drawSample :: Sample Double -> String
drawSample sample =
  drawImage (fst sample) ++ "Label: " ++ show (fromLabel $ snd sample)

-- | The 50k-image MNIST training set as (label, image) pairs.
trainIms :: IO [(Int, Vector R)]
trainIms =
  readImages
    "mnist-data/train-images-idx3-ubyte.gz"
    "mnist-data/train-labels-idx1-ubyte.gz"
    50000

-- | Training set as (input, one-hot target) samples.
trainSamples :: IO (Samples Double)
trainSamples = do
  ims <- trainIms
  return $! [img --> toLabel lbl | (lbl, img) <- ims]

-- | The 10k-image MNIST test set as (label, image) pairs.
testIms :: IO [(Int, Vector R)]
testIms =
  readImages
    "mnist-data/t10k-images-idx3-ubyte.gz"
    "mnist-data/t10k-labels-idx1-ubyte.gz"
    10000

-- | Test set as (input, one-hot target) samples.
testSamples :: IO (Samples Double)
testSamples = do
  ims <- testIms
  return $! [img --> toLabel lbl | (lbl, img) <- ims]

-- | Wrapper so a 'Sample' can be (de)serialized via its list
-- representation without an orphan 'Binary' instance.
newtype StorableSample a = StorableSample { getSample :: Sample a }

instance (Element a, Binary a) => Binary (StorableSample a) where
  put (StorableSample (vecIn, vecOut)) = do
    put (toList vecIn)
    put (toList vecOut)
  get = do
    vecIn  <- get
    vecOut <- get
    return $ StorableSample (fromList vecIn --> fromList vecOut)

-- | Load gzipped, binary-encoded samples from disk.
loadSamples :: FilePath -> IO (Samples Double)
loadSamples fp = fmap (map getSample . decode . decompress) (B.readFile fp)

-- | Save samples to disk, binary-encoded and gzipped.
saveSamples :: FilePath -> Samples Double -> IO ()
saveSamples fp = B.writeFile fp . compress . encode . map StorableSample

-- | Print a sample's image and the digit the network recognizes it as.
classify :: Network Double -> Sample Double -> IO ()
classify net spl =
  putStrLn
    ( drawImage (fst spl)
        ++ "Recognized as "
        ++ show (fromLabel $ output net activation (fst spl))
    )

-- | Count how many samples the network classifies correctly.
--
-- Uses the file-wide 'activation' alias (previously this called
-- 'sigmoid' directly, inconsistently with 'classify'; the two are
-- identical since @activation = sigmoid@).
test :: Network Double -> Samples Double -> Int
test net = length . filter correct
  where
    -- A sample is a hit when the strongest output matches the label.
    correct (img, lbl) = fromLabel (output net activation img) == fromLabel lbl

-- | Activation function used throughout the program.
activation :: (Floating a) => ActivationFunction a
activation = sigmoid

-- | Derivative matching 'activation'.
activation' :: (Floating a) => ActivationFunctionDerivative a
activation' = sigmoid'

main :: IO ()
main = do
  args <- cmdArgs arguments
  -- Either start a fresh 784 -> hidden -> 10 network or resume from disk.
  net <- case filePath args of
    "" -> newNetwork [784, hiddenNeurons args, 10]
    fp -> loadNetwork fp :: IO (Network Double)
  trSamples  <- loadSamples "mnist-data/trainSamples.spls"
  tstSamples <- loadSamples "mnist-data/testSamples.spls"
  -- Force the test set fully before the baseline measurement.
  let bad = tstSamples `deepseq` test net tstSamples
  putStrLn $ "Initial performance of network: recognized " ++ show bad ++ " of 10k"
  -- Per-epoch progress line passed to the trainer.
  let debug network epochsLeft =
        "Left epochs: " ++ show epochsLeft
          ++ " recognized: " ++ show (test network tstSamples) ++ " of 10k"
  smartNet <-
    trSamples `deepseq`
      trainShuffled
        (epochs args)
        debug
        net
        (toCostFunction $ costFunction args)
        (lambda args)
        trSamples
        (miniBatchSize args)
        (eta args)
  let res = test smartNet tstSamples
  putStrLn $ "finished testing. recognized: " ++ show res ++ " of 10k"
  putStrLn "saving network"
  saveNetwork "mnist.net" smartNet