Haskell Artificial Neural Network library
Vous ne pouvez pas sélectionner plus de 25 sujets Les noms de sujets doivent commencer par une lettre ou un nombre, peuvent contenir des tirets ('-') et peuvent comporter jusqu'à 35 caractères.

134 lignes
4.7KB

  1. {-# LANGUAGE DeriveDataTypeable, BangPatterns #-}
  2. import qualified Data.ByteString.Lazy as BS
  3. import Data.Int
  4. import Data.List (findIndex, maximumBy)
  5. import Data.List.Split (chunksOf)
  6. import Data.Ord (comparing)
  7. import Debug.Trace (trace)
  8. import Control.DeepSeq
  9. import Numeric.LinearAlgebra
  10. import Codec.Compression.GZip
  11. import System.Random
  12. import System.Environment (getArgs)
  13. import System.Console.CmdArgs.Implicit
  14. import Network
-- | Cost function selectable from the command line; mapped onto the
--   library's 'CostFunction' by 'toCostFunction'.
data Cost = Quadratic | CrossEntropy
  deriving (Show, Data)
  17. toCostFunction :: Cost -> CostFunction
  18. toCostFunction Quadratic = QuadraticCost
  19. toCostFunction CrossEntropy = CrossEntropyCost
-- | Command-line options parsed by cmdargs (see 'arguments' for defaults).
data Arguments = Arguments
  { eta           :: Double    -- ^ learning rate
  , lambda        :: Double    -- ^ regularization strength
  , filePath      :: FilePath  -- ^ saved-network file; "" means start fresh
  , costFunction  :: Cost      -- ^ training cost function
  , epochs        :: Int       -- ^ number of training epochs
  , miniBatchSize :: Int       -- ^ samples per SGD mini batch
  , hiddenNeurons :: Int       -- ^ size of the single hidden layer
  }
  deriving (Show, Data, Typeable)
  25. arguments = Arguments { eta = 0.5 &= help "Learning rate",
  26. lambda = 5 &= help "Lambda of regularization",
  27. filePath = "" &= help "Load network from file",
  28. costFunction = Quadratic &= help "Cost function",
  29. epochs = 30 &= help "Number of training epochs",
  30. miniBatchSize = 10 &= help "Mini batch size",
  31. hiddenNeurons = 30 &= help "Number of neurons in hidden layer" }
  32. &= summary "MNIST Image classifier in Haskell v1"
  33. readImages :: FilePath -> FilePath -> Int64 -> IO ([(Int, Vector R)])
  34. readImages !imgPath !lblPath !n = do
  35. imgBytes <- fmap decompress (BS.readFile imgPath)
  36. lblBytes <- fmap decompress (BS.readFile lblPath)
  37. let !imgs = map (readImage imgBytes) [0..n-1]
  38. !lbs = map (readLabel lblBytes) [0..n-1]
  39. return $! zip lbs imgs
  40. readImage :: BS.ByteString -> Int64 -> Vector R
  41. readImage !bytes !n = vector $! map ((/256) . fromIntegral . BS.index bytes . (n*28^2 + 16 +)) [0..783]
  42. readLabel :: BS.ByteString -> Int64 -> Int
  43. readLabel bytes n = fromIntegral $! BS.index bytes (n + 8)
  44. toLabel :: Int -> Vector R
  45. toLabel n = fromList [ if i == n then 1 else 0 | i <- [0..9]]
  46. fromLabel :: Vector R -> Int
  47. fromLabel vec = snd $ maximumBy (comparing fst) (zip (toList vec) [0..])
  48. drawImage :: Vector R -> String
  49. drawImage vec = concatMap toLine (chunksOf 28 (toList vec))
  50. where toLine ps = (map toChar ps) ++ "\n"
  51. toChar v
  52. | v > 0.5 = 'o'
  53. | otherwise = '.'
  54. drawSample :: Sample Double -> String
  55. drawSample sample = drawImage (fst sample)
  56. ++ "Label: "
  57. ++ show (fromLabel $ snd sample)
  58. trainIms :: IO [(Int, Vector R)]
  59. trainIms = readImages "mnist-data/train-images-idx3-ubyte.gz"
  60. "mnist-data/train-labels-idx1-ubyte.gz"
  61. 50000
  62. trainSamples :: IO (Samples Double)
  63. trainSamples = do
  64. ims <- trainIms
  65. return $! [ img --> toLabel lbl | (lbl, img) <- ims]
  66. testIms :: IO [(Int, Vector R)]
  67. testIms = readImages "mnist-data/t10k-images-idx3-ubyte.gz"
  68. "mnist-data/t10k-labels-idx1-ubyte.gz"
  69. 10000
  70. testSamples :: IO (Samples Double)
  71. testSamples = do
  72. ims <- testIms
  73. return $! [ img --> toLabel lbl | (lbl, img) <- ims]
  74. classify :: Network Double -> Sample Double -> IO ()
  75. classify net spl = do
  76. putStrLn (drawImage (fst spl)
  77. ++ "Recognized as "
  78. ++ show (fromLabel $ output net activation (fst spl)))
  79. test :: Network Double -> Samples Double -> Int
  80. test net = let f (img, lbl) = (fromLabel $ output net sigmoid img, fromLabel lbl) in
  81. hits . map f
  82. where hits :: [(Int, Int)] -> Int
  83. hits result = sum $ map (\(a, b) -> if a == b then 1 else 0) result
-- | The activation function used throughout this program; swap it here to
--   change activations globally.
activation :: (Floating a) => ActivationFunction a
activation = sigmoid
-- | Derivative of 'activation'; must be kept in sync with it.
activation' :: (Floating a) => ActivationFunctionDerivative a
activation' = sigmoid'
  88. main = do
  89. args <- cmdArgs arguments
  90. net <- case filePath args of
  91. "" -> newNetwork [784, hiddenNeurons args, 10]
  92. fp -> loadNetwork fp :: IO (Network Double)
  93. trSamples <- trainSamples
  94. tstSamples <- testSamples
  95. let bad = tstSamples `deepseq` test net tstSamples
  96. putStrLn $ "Initial performance of network: recognized " ++ show bad ++ " of 10k"
  97. let debug network epochs = "Left epochs: " ++ show epochs
  98. ++ " recognized: " ++ show (test network tstSamples) ++ " of 10k"
  99. smartNet <- trSamples `deepseq` (trainShuffled
  100. (epochs args) debug net
  101. (toCostFunction $ costFunction args)
  102. (lambda args) trSamples
  103. (miniBatchSize args)
  104. (eta args))
  105. let res = test smartNet tstSamples
  106. putStrLn $ "finished testing. recognized: " ++ show res ++ " of 10k"
  107. putStrLn "saving network"
  108. saveNetwork "mnist.net" smartNet