-- Haskell artificial neural networking library.
  1. {-# LANGUAGE ScopedTypeVariables, FlexibleContexts, BangPatterns #-}
  2. {-# OPTIONS -Wall #-}
  3. -- |
  4. -- Module : Network
  5. -- Copyright : (c) 2017 Christian Merten
  6. -- Maintainer : c.merten@gmx.net
  7. -- Stability : experimental
  8. -- Portability : GHC
  9. --
  10. -- An implementation of artifical feed-forward neural networks in pure Haskell.
  11. --
  12. -- An example is added in /XOR.hs/
  13. module Network (
  14. -- * Network
  15. Network(..),
  16. Layer(..),
  17. newNetwork,
  18. output,
  19. -- * Learning functions
  20. trainShuffled,
  21. trainNTimes,
  22. CostFunction(..),
  23. getDelta,
  24. LearningRate,
  25. Lambda,
  26. TrainingDataLength,
  27. Sample, Samples, (-->),
  28. -- * Activation functions
  29. ActivationFunction, ActivationFunctionDerivative,
  30. sigmoid,
  31. sigmoid',
  32. -- * Network serialization
  33. saveNetwork,
  34. loadNetwork
  35. ) where
import Control.Monad (forM, zipWithM)
import Data.Array.IO
import Data.Binary
import Data.List (foldl')
import Data.List.Split (chunksOf)
import Data.Maybe (fromMaybe)
import Debug.Trace (trace)
import Numeric.LinearAlgebra
import System.Directory
import System.Random
import Text.Read (readMaybe)
import Text.Regex.PCRE
  48. -- | The generic feedforward network type, a binary instance is implemented.
  49. -- It takes a list of layers
  50. -- with a minimum of one (output layer).
  51. -- It is usually constructed using the `newNetwork` function.
  52. data Network a = Network { layers :: [Layer a] }
  53. deriving (Show)
  54. -- | One layer of a network, storing the weights matrix and the biases vector
  55. -- of this layer.
  56. data Layer a = Layer { weights :: Matrix a, biases :: Vector a }
  57. deriving (Show)
  58. instance (Element a, Binary a) => Binary (Network a) where
  59. put (Network ls) = put ls
  60. get = Network `fmap` get
  61. instance (Element a, Binary a) => Binary (Layer a) where
  62. put (Layer ws bs) = do
  63. put (toLists ws)
  64. put (toList bs)
  65. get = do
  66. ws <- get
  67. bs <- get
  68. return $ Layer (fromLists ws) (fromList bs)
  69. -- | Cost Function Enum
  70. data CostFunction = QuadraticCost
  71. | CrossEntropyCost
  72. deriving (Show, Eq)
  73. -- | getDelta based on the raw input, the activated input and the desired output
  74. -- results in different values depending on the CostFunction type.
  75. getDelta :: Floating a => CostFunction -> a -> a -> a -> a
  76. getDelta QuadraticCost z a y = (a - y) * sigmoid'(z)
  77. getDelta CrossEntropyCost _ a y = a - y
  78. -- | Activation function used to calculate the actual output of a neuron.
  79. -- Usually the 'sigmoid' function.
  80. type ActivationFunction a = a -> a
  81. -- | The derivative of an activation function.
  82. type ActivationFunctionDerivative a = a -> a
  83. -- | Training sample that can be used for the training functions.
  84. --
  85. -- > trainingData :: Samples Double
  86. -- > trainingData = [ fromList [0, 0] --> fromList [0],
  87. -- > fromList [0, 1] --> fromList [1],
  88. -- > fromList [1, 0] --> fromList [1],
  89. -- > fromList [1, 1] --> fromList [0]]
  90. type Sample a = (Vector a, Vector a)
  91. -- | A list of 'Sample's
  92. type Samples a = [Sample a]
  93. -- | A simple synonym for the (,) operator, used to create samples very intuitively.
  94. (-->) :: Vector a -> Vector a -> Sample a
  95. (-->) = (,)
  96. -- | The learning rate, affects the learning speed, lower learning rate results
  97. -- in slower learning, but usually better results after more epochs.
  98. type LearningRate = Double
  99. -- | Lambda value affecting the regularization while learning.
  100. type Lambda = Double
  101. -- | Wrapper around the training data length.
  102. type TrainingDataLength = Int
  103. -- | Initializes a new network with random values for weights and biases
  104. -- in all layers.
  105. --
  106. -- > net <- newNetwork [2, 3, 4]
  107. newNetwork :: [Int] -> IO (Network Double)
  108. newNetwork layerSizes
  109. | length layerSizes < 2 = error "Network too small!"
  110. | otherwise = do
  111. lays <- zipWithM go (init layerSizes) (tail layerSizes)
  112. return $ Network lays
  113. where go :: Int -> Int -> IO (Layer Double)
  114. go inputSize outputSize = do
  115. ws <- randn outputSize inputSize
  116. seed <- randomIO
  117. let bs = randomVector seed Gaussian outputSize
  118. return $ Layer ws bs
  119. -- | Calculate the output of the network based on the network, a given
  120. -- 'ActivationFunction' and the input vector.
  121. output :: (Numeric a, Num (Vector a))
  122. => Network a
  123. -> ActivationFunction a
  124. -> Vector a
  125. -> Vector a
  126. output net act input = foldl' f input (layers net)
  127. where f vec layer = cmap act ((weights layer #> vec) + biases layer)
  128. rawOutputs :: (Numeric a, Num (Vector a))
  129. => Network a
  130. -> ActivationFunction a
  131. -> Vector a
  132. -> [(Vector a, Vector a)]
  133. rawOutputs net act input = scanl f (input, input) (layers net)
  134. where f (_, a) layer = let z' = (weights layer #> a) + biases layer in
  135. (z', cmap act z')
  136. -- | The most used training function, randomly shuffling the training set before
  137. -- every training epoch
  138. --
  139. -- > trainShuffled 30 (\n e -> "") net CrossEntropyCost 0.5 trainData 10 0.1
  140. trainShuffled :: Int
  141. -> (Network Double -> Int -> String)
  142. -> Network Double
  143. -> CostFunction
  144. -> Lambda
  145. -> Samples Double
  146. -> Int
  147. -> Double
  148. -> IO (Network Double)
  149. trainShuffled 0 _ net _ _ _ _ _ = return net
  150. trainShuffled epochs debug net costFunction lambda trainSamples miniBatchSize eta = do
  151. spls <- shuffle trainSamples
  152. let !net' = trainSGD net costFunction lambda spls miniBatchSize eta
  153. trace (debug net' epochs)
  154. (trainShuffled (epochs - 1) debug net' costFunction lambda trainSamples miniBatchSize eta)
  155. -- | Pure version of 'trainShuffled', training the network /n/ times without
  156. -- shuffling the training set, resulting in slightly worse results.
  157. trainNTimes :: Int
  158. -> (Network Double -> Int -> String)
  159. -> Network Double
  160. -> CostFunction
  161. -> Lambda
  162. -> Samples Double
  163. -> Int
  164. -> Double
  165. -> Network Double
  166. trainNTimes 0 _ net _ _ _ _ _ = net
  167. trainNTimes epochs debug net costFunction lambda trainSamples miniBatchSize eta =
  168. trace (debug net' epochs)
  169. (trainNTimes (epochs - 1) debug net' costFunction lambda trainSamples miniBatchSize eta)
  170. where !net' = trainSGD net costFunction lambda trainSamples miniBatchSize eta
  171. trainSGD :: (Numeric Double, Floating Double)
  172. => Network Double
  173. -> CostFunction
  174. -> Lambda
  175. -> Samples Double
  176. -> Int
  177. -> Double
  178. -> Network Double
  179. trainSGD net costFunction lambda trainSamples miniBatchSize eta =
  180. foldl' updateMiniBatch net (chunksOf miniBatchSize trainSamples)
  181. where updateMiniBatch = update eta costFunction lambda (length trainSamples)
  182. update :: LearningRate
  183. -> CostFunction
  184. -> Lambda
  185. -> TrainingDataLength
  186. -> Network Double
  187. -> Samples Double
  188. -> Network Double
  189. update eta costFunction lambda n net spls = case newNablas of
  190. Nothing -> net
  191. Just x -> net { layers = layers' x }
  192. where newNablas :: Maybe [Layer Double]
  193. newNablas = foldl' updateNablas Nothing spls
  194. updateNablas :: Maybe [Layer Double] -> Sample Double -> Maybe [Layer Double]
  195. updateNablas mayNablas sample =
  196. let !nablasDelta = backprop net costFunction sample
  197. f nabla nablaDelta =
  198. nabla { weights = weights nabla + weights nablaDelta,
  199. biases = biases nabla + biases nablaDelta }
  200. in case mayNablas of
  201. Just nablas -> Just $ zipWith f nablas nablasDelta
  202. Nothing -> Just $ nablasDelta
  203. layers' :: [Layer Double] -> [Layer Double]
  204. layers' nablas = zipWith updateLayer (layers net) nablas
  205. updateLayer :: Layer Double -> Layer Double -> Layer Double
  206. updateLayer layer nabla =
  207. let w = weights layer -- weights matrix
  208. nw = weights nabla
  209. b = biases layer -- biases vector
  210. nb = biases nabla
  211. fac = 1 - eta * (lambda / fromIntegral n)
  212. w' = scale fac w - scale (eta / (fromIntegral $ length spls)) nw
  213. b' = b - scale (eta / (fromIntegral $ length spls)) nb
  214. in layer { weights = w', biases = b' }
  215. backprop :: Network Double -> CostFunction -> Sample Double -> [Layer Double]
  216. backprop net costFunction spl = finalNablas
  217. where rawFeedforward :: [(Vector Double, Vector Double)]
  218. rawFeedforward = reverse $ rawOutputs net sigmoid (fst spl)
  219. -- get starting activation and raw value
  220. headZ, headA :: Vector Double
  221. (headZ, headA) = head rawFeedforward
  222. -- get starting delta, based on the activation of the last layer
  223. startDelta = getDelta costFunction headZ headA (snd spl)
  224. -- calculate nabla of biases
  225. lastNablaB = startDelta
  226. -- calculate nabla of weighs of last layer in advance
  227. lastNablaW = startDelta `outer` previousA
  228. where previousA
  229. | length rawFeedforward > 1 = snd $ rawFeedforward !! 1
  230. | otherwise = fst spl
  231. lastLayer = Layer { weights = lastNablaW, biases = lastNablaB }
  232. -- reverse layers, analogy to the reversed (z, a) list
  233. layersReversed = reverse $ layers net
  234. -- calculate nablas, beginning at the end of the network (startDelta)
  235. (finalNablas, _) = foldl' calculate ([lastLayer], startDelta)
  236. [1..length layersReversed - 1]
  237. -- takes the index and updates nablas
  238. calculate (nablas, oldDelta) idx =
  239. let -- extract raw and activated value
  240. (z, _) = rawFeedforward !! idx
  241. -- apply prime derivative of sigmoid
  242. z' = cmap sigmoid' z
  243. -- calculate new delta
  244. w = weights $ layersReversed !! (idx - 1)
  245. delta = (tr w #> oldDelta) * z'
  246. -- nablaB is just the delta vector
  247. nablaB = delta
  248. -- activation in previous layer
  249. aPrevious = snd $ rawFeedforward !! (idx + 1)
  250. -- dot product of delta and the activation in the previous layer
  251. nablaW = delta `outer` aPrevious
  252. -- put nablas into a new layer
  253. in (Layer { weights = nablaW, biases = nablaB } : nablas, delta)
  254. -- | The sigmoid function
  255. sigmoid :: Floating a => ActivationFunction a
  256. sigmoid x = 1 / (1 + exp (-x))
  257. -- | The derivative of the sigmoid function.
  258. sigmoid' :: Floating a => ActivationFunctionDerivative a
  259. sigmoid' x = sigmoid x * (1 - sigmoid x)
  260. shuffle :: [a] -> IO [a]
  261. shuffle xs = do
  262. ar <- newArr n xs
  263. forM [1..n] $ \i -> do
  264. j <- randomRIO (i,n)
  265. vi <- readArray ar i
  266. vj <- readArray ar j
  267. writeArray ar j vi
  268. return vj
  269. where
  270. n = length xs
  271. newArr :: Int -> [a] -> IO (IOArray Int a)
  272. newArr len lst = newListArray (1,len) lst
  273. -- | Saves the network as the given filename. When the file already exists,
  274. -- it looks for another filename by increasing the version, e.g
  275. -- /mnist.net/ becomes /mnist1.net/.
  276. saveNetwork :: (Element a, Binary a) => FilePath -> Network a -> IO ()
  277. saveNetwork fp net = do
  278. ex <- doesFileExist fp
  279. case ex of
  280. True -> saveNetwork (newFileName fp) net
  281. False -> encodeFile fp net
  282. newFileName :: FilePath -> FilePath
  283. newFileName fp = case fp =~ "(.+[a-z]){0,1}([0-9]*)(\\..*)" :: [[String]] of
  284. [[_, p, v, s]] -> p ++ show (version v + 1) ++ s
  285. _ -> fp ++ "l"
  286. where version :: String -> Int
  287. version xs = fromMaybe 0 (readMaybe xs :: Maybe Int)
  288. -- | Load the network with the given filename.
  289. loadNetwork :: (Element a, Binary a) => FilePath -> IO (Network a)
  290. loadNetwork = decodeFile