Haskell Artificial Neural Networking library


{-# LANGUAGE DeriveDataTypeable, BangPatterns #-}

import qualified Data.ByteString.Lazy as BS
import Data.Int
import Data.Binary
import Data.List (findIndex, maximumBy)
import Data.List.Split (chunksOf)
import Data.Ord (comparing)
import Debug.Trace (trace)
import Codec.Compression.GZip
import qualified Data.ByteString.Lazy as B
import Control.DeepSeq
import Numeric.LinearAlgebra
import System.Random
import System.Environment (getArgs)
import System.Console.CmdArgs.Implicit
import Network

data Cost = Quadratic | CrossEntropy
    deriving (Show, Data)

toCostFunction :: Cost -> CostFunction
toCostFunction Quadratic    = QuadraticCost
toCostFunction CrossEntropy = CrossEntropyCost

-- | Command-line options, parsed by cmdargs.
data Arguments = Arguments { eta :: Double, lambda :: Double,
                             filePath :: FilePath, costFunction :: Cost,
                             epochs :: Int, miniBatchSize :: Int,
                             hiddenNeurons :: Int }
    deriving (Show, Data, Typeable)

-- | Default option values and their help texts.
arguments :: Arguments
arguments = Arguments { eta = 0.5 &= help "Learning rate",
                        lambda = 5 &= help "Lambda of regularization",
                        filePath = "" &= help "Load network from file",
                        costFunction = CrossEntropy &= help "Cost function",
                        epochs = 30 &= help "Number of training epochs",
                        miniBatchSize = 10 &= help "Mini batch size",
                        hiddenNeurons = 30 &= help "Number of neurons in hidden layer" }
            &= summary "MNIST Image classifier in Haskell v1"

-- | Read the first n (label, image) pairs from gzipped MNIST IDX files.
readImages :: FilePath -> FilePath -> Int64 -> IO [(Int, Vector R)]
readImages !imgPath !lblPath !n = do
    imgBytes <- fmap decompress (BS.readFile imgPath)
    lblBytes <- fmap decompress (BS.readFile lblPath)
    let !imgs = map (readImage imgBytes) [0..n-1]
        !lbs  = map (readLabel lblBytes) [0..n-1]
    return $! zip lbs imgs

-- | Extract the n-th 28x28 image, skipping the 16-byte header.
-- Each pixel byte (0..255) is divided by 256, so values lie in [0, 1).
readImage :: BS.ByteString -> Int64 -> Vector R
readImage !bytes !n = vector $! map ((/256) . fromIntegral . BS.index bytes . (n*28^2 + 16 +)) [0..783]

-- | Extract the n-th label, skipping the 8-byte header.
readLabel :: BS.ByteString -> Int64 -> Int
readLabel bytes n = fromIntegral $! BS.index bytes (n + 8)

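-- | One-hot encode a digit (0..9) as a 10-element vector.
toLabel :: Int -> Vector R
toLabel n = fromList [ if i == n then 1 else 0 | i <- [0..9]]

-- | Decode a 10-element output vector as the index of its largest component.
fromLabel :: Vector R -> Int
fromLabel vec = snd $ maximumBy (comparing fst) (zip (toList vec) [0..])
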
drawImage :: Vector R -> String
drawImage vec = concatMap toLine (chunksOf 28 (toList vec))
  where
    toLine ps = map toChar ps ++ "\n"
    toChar v
        | v > 0.5   = 'o'
        | otherwise = '.'

drawSample :: Sample Double -> String
drawSample sample = drawImage (fst sample)
    ++ "Label: "
    ++ show (fromLabel $ snd sample)

trainIms :: IO [(Int, Vector R)]
trainIms = readImages "mnist-data/train-images-idx3-ubyte.gz"
                      "mnist-data/train-labels-idx1-ubyte.gz"
                      50000

trainSamples :: IO (Samples Double)
trainSamples = do
    ims <- trainIms
    return $! [ img --> toLabel lbl | (lbl, img) <- ims]

testIms :: IO [(Int, Vector R)]
testIms = readImages "mnist-data/t10k-images-idx3-ubyte.gz"
                     "mnist-data/t10k-labels-idx1-ubyte.gz"
                     10000

testSamples :: IO (Samples Double)
testSamples = do
    ims <- testIms
    return $! [ img --> toLabel lbl | (lbl, img) <- ims]

-- | Wrapper that serializes a sample as two plain lists.
newtype StorableSample a = StorableSample { getSample :: Sample a }

instance (Element a, Binary a) => Binary (StorableSample a) where
    put (StorableSample (vecIn, vecOut)) = do
        put (toList vecIn)
        put (toList vecOut)
    get = do
        vecIn  <- get
        vecOut <- get
        return $ StorableSample (fromList vecIn --> fromList vecOut)

-- | Load samples from a gzipped binary file.
loadSamples :: FilePath -> IO (Samples Double)
loadSamples fp = fmap (map getSample . decode . decompress) (B.readFile fp)

-- | Save samples to a gzipped binary file.
saveSamples :: FilePath -> Samples Double -> IO ()
saveSamples fp = B.writeFile fp . compress . encode . map StorableSample

-- | Print a sample's image together with the digit the network recognizes.
classify :: Network Double -> Sample Double -> IO ()
classify net spl =
    putStrLn (drawImage (fst spl)
        ++ "Recognized as "
        ++ show (fromLabel $ output net activation (fst spl)))

-- | Count how many samples the network classifies correctly.
test :: Network Double -> Samples Double -> Int
test net = hits . map f
  where
    -- Use the shared 'activation' (currently sigmoid), as 'classify' does.
    f (img, lbl) = (fromLabel $ output net activation img, fromLabel lbl)
    hits :: [(Int, Int)] -> Int
    hits = length . filter (uncurry (==))

activation :: (Floating a) => ActivationFunction a
activation = sigmoid

activation' :: (Floating a) => ActivationFunctionDerivative a
activation' = sigmoid'

main :: IO ()
main = do
    args <- cmdArgs arguments
    -- Start from a fresh network unless a saved one is given on the command line.
    net <- case filePath args of
        "" -> newNetwork [784, hiddenNeurons args, 10]
        fp -> loadNetwork fp :: IO (Network Double)
    trSamples <- loadSamples "mnist-data/trainSamples.spls"
    tstSamples <- loadSamples "mnist-data/testSamples.spls"
    -- Number of test samples recognized correctly before training.
    let initialHits = tstSamples `deepseq` test net tstSamples
    putStrLn $ "Initial performance of network: recognized " ++ show initialHits ++ " of 10k"
    -- Progress message passed to trainShuffled; 'remaining' avoids shadowing the 'epochs' field.
    let debug network remaining = "Left epochs: " ++ show remaining
            ++ " recognized: " ++ show (test network tstSamples) ++ " of 10k"
    smartNet <- trSamples `deepseq` (trainShuffled
        (epochs args) debug net
        (toCostFunction $ costFunction args)
        (lambda args) trSamples
        (miniBatchSize args)
        (eta args))
    let res = test smartNet tstSamples
    putStrLn $ "finished testing. recognized: " ++ show res ++ " of 10k"
    putStrLn "saving network"
    saveNetwork "mnist.net" smartNet