Haskell Artificial Neural Networking library
Вы не можете выбрать более 25 тем Темы должны начинаться с буквы или цифры, могут содержать дефисы(-) и должны содержать не более 35 символов.

270 строк
10.0KB

  1. {-# LANGUAGE ScopedTypeVariables, FlexibleContexts, BangPatterns #-}
  2. {-# OPTIONS -Wall #-}
  3. module Network where
import Control.Monad (zipWithM, forM)
import Data.Array.IO
import Data.Binary
import Data.List (foldl')
import Data.List.Split (chunksOf)
import Data.Maybe (fromMaybe)
import Debug.Trace (trace)
import Numeric.LinearAlgebra
import System.Directory
import System.IO (hPutStrLn, stderr)
import System.Random
import Text.Read (readMaybe)
import Text.Regex.PCRE
-- | The generic feedforward network type, a binary instance is implemented.
-- It takes a list of layers
-- with a minimum of one (output layer).
-- It is usually constructed using the `newNetwork` function, initializing the matrices
-- with some default random values.
--
-- > net <- newNetwork [2, 3, 4]
data Network a = Network { layers :: [Layer a] }
  deriving (Show)
-- | One layer of a network, storing the weights matrix and the biases vector
-- of this layer.
data Layer a = Layer { weights :: Matrix a, biases :: Vector a }
  deriving (Show)
  28. instance (Element a, Binary a) => Binary (Network a) where
  29. put (Network ls) = put ls
  30. get = Network `fmap` get
  31. instance (Element a, Binary a) => Binary (Layer a) where
  32. put (Layer ws bs) = do
  33. put (toLists ws)
  34. put (toList bs)
  35. get = do
  36. ws <- get
  37. bs <- get
  38. return $ Layer (fromLists ws) (fromList bs)
-- | Cost Function Enum.
-- Selects how the output-layer error term is computed in 'getDelta'.
data CostFunction = QuadraticCost
                  | CrossEntropyCost
                  deriving (Show, Eq)
  43. -- | getDelta based on the raw input, the activated input and the desired output
  44. -- results in different values depending on the CostFunction type.
  45. getDelta :: Floating a => CostFunction -> a -> a -> a -> a
  46. getDelta QuadraticCost z a y = (a - y) * sigmoid'(z)
  47. getDelta CrossEntropyCost _ a y = a - y
  48. type ActivationFunction a = a -> a
  49. type ActivationFunctionDerivative a = a -> a
  50. type Sample a = (Vector a, Vector a)
  51. type Samples a = [Sample a]
  52. -- | A simple synonym for the (,) operator, used to create samples very intuitively.
  53. (-->) :: Vector a -> Vector a -> Sample a
  54. (-->) = (,)
  55. type LearningRate = Double
  56. type Lambda = Double
  57. type TrainingDataLength = Int
  58. newNetwork :: [Int] -> IO (Network Double)
  59. newNetwork layerSizes
  60. | length layerSizes < 2 = error "Network too small!"
  61. | otherwise = do
  62. lays <- zipWithM go (init layerSizes) (tail layerSizes)
  63. return $ Network lays
  64. where go :: Int -> Int -> IO (Layer Double)
  65. go inputSize outputSize = do
  66. ws <- randn outputSize inputSize
  67. seed <- randomIO
  68. let bs = randomVector seed Gaussian outputSize
  69. return $ Layer ws bs
  70. output :: (Numeric a, Num (Vector a))
  71. => Network a
  72. -> ActivationFunction a
  73. -> Vector a
  74. -> Vector a
  75. output net act input = foldl f input (layers net)
  76. where f vec layer = cmap act ((weights layer #> vec) + biases layer)
  77. outputs :: (Numeric a, Num (Vector a))
  78. => Network a
  79. -> ActivationFunction a
  80. -> Vector a
  81. -> [Vector a]
  82. outputs net act input = scanl f input (layers net)
  83. where f vec layer = cmap act ((weights layer #> vec) + biases layer)
  84. rawOutputs :: (Numeric a, Num (Vector a))
  85. => Network a
  86. -> ActivationFunction a
  87. -> Vector a
  88. -> [(Vector a, Vector a)]
  89. rawOutputs net act input = scanl f (input, input) (layers net)
  90. where f (_, a) layer = let z' = (weights layer #> a) + biases layer in
  91. (z', cmap act z')
  92. -- | The most used training function, randomly shuffling the training set before
  93. -- every training epoch
  94. --
  95. -- > trainShuffled 30 (\n e -> "") net CrossEntropyCost 0.5 trainData 10 0.1
  96. trainShuffled :: Int
  97. -> (Network Double -> Int -> String)
  98. -> Network Double
  99. -> CostFunction
  100. -> Lambda
  101. -> Samples Double
  102. -> Int
  103. -> Double
  104. -> IO (Network Double)
  105. trainShuffled 0 _ net _ _ _ _ _ = return net
  106. trainShuffled epochs debug net costFunction lambda trainSamples miniBatchSize eta = do
  107. spls <- shuffle trainSamples
  108. let !net' = trainSGD net costFunction lambda spls miniBatchSize eta
  109. trace (debug net' epochs)
  110. (trainShuffled (epochs - 1) debug net' costFunction lambda trainSamples miniBatchSize eta)
  111. trainNTimes :: Int
  112. -> (Network Double -> Int -> String)
  113. -> Network Double
  114. -> CostFunction
  115. -> Lambda
  116. -> Samples Double
  117. -> Int
  118. -> Double
  119. -> Network Double
  120. trainNTimes 0 _ net _ _ _ _ _ = net
  121. trainNTimes epochs debug net costFunction lambda trainSamples miniBatchSize eta =
  122. trace (debug net' epochs)
  123. (trainNTimes (epochs - 1) debug net' costFunction lambda trainSamples miniBatchSize eta)
  124. where !net' = trainSGD net costFunction lambda trainSamples miniBatchSize eta
  125. trainSGD :: (Numeric Double, Floating Double)
  126. => Network Double
  127. -> CostFunction
  128. -> Lambda
  129. -> Samples Double
  130. -> Int
  131. -> Double
  132. -> Network Double
  133. trainSGD net costFunction lambda trainSamples miniBatchSize eta =
  134. foldl updateMiniBatch net (chunksOf miniBatchSize trainSamples)
  135. where updateMiniBatch = update eta costFunction lambda (length trainSamples)
  136. update :: LearningRate
  137. -> CostFunction
  138. -> Lambda
  139. -> TrainingDataLength
  140. -> Network Double
  141. -> Samples Double
  142. -> Network Double
  143. update eta costFunction lambda n net spls = case newNablas of
  144. Nothing -> net
  145. Just x -> net { layers = layers' x }
  146. where newNablas :: Maybe [Layer Double]
  147. newNablas = foldl updateNablas Nothing spls
  148. updateNablas :: Maybe [Layer Double] -> Sample Double -> Maybe [Layer Double]
  149. updateNablas mayNablas sample =
  150. let nablasDelta = backprop net costFunction sample
  151. f nabla nablaDelta =
  152. nabla { weights = weights nabla + weights nablaDelta,
  153. biases = biases nabla + biases nablaDelta }
  154. in case mayNablas of
  155. Just nablas -> Just $ zipWith f nablas nablasDelta
  156. Nothing -> Just $ nablasDelta
  157. layers' :: [Layer Double] -> [Layer Double]
  158. layers' nablas = zipWith updateLayer (layers net) nablas
  159. updateLayer :: Layer Double -> Layer Double -> Layer Double
  160. updateLayer layer nabla =
  161. let w = weights layer -- weights matrix
  162. nw = weights nabla
  163. b = biases layer -- biases vector
  164. nb = biases nabla
  165. fac = 1 - eta * (lambda / fromIntegral n)
  166. w' = scale fac w - scale (eta / (fromIntegral $ length spls)) nw
  167. b' = b - scale (eta / (fromIntegral $ length spls)) nb
  168. in layer { weights = w', biases = b' }
-- | Back-propagation for a single training sample: returns one gradient
-- 'Layer' (nabla_w, nabla_b) per network layer, ordered like 'layers'.
-- NOTE(review): relies on hmatrix's Num/Floating instances for Vector Double
-- so 'getDelta' and arithmetic operate element-wise on whole vectors.
backprop :: Network Double -> CostFunction -> Sample Double -> [Layer Double]
backprop net costFunction spl = finalNablas
  where rawFeedforward :: [(Vector Double, Vector Double)]
        -- (z, a) pairs for every layer, reversed so the output layer is first;
        -- the last element duplicates the input (see 'rawOutputs').
        rawFeedforward = reverse $ rawOutputs net sigmoid (fst spl)
        -- get starting activation and raw value of the output layer
        headZ, headA :: Vector Double
        (headZ, headA) = head rawFeedforward
        -- get starting delta, based on the activation of the last layer
        -- and the expected output (snd spl)
        startDelta = getDelta costFunction headZ headA (snd spl)
        -- calculate weights gradient of the last layer in advance
        lastNablaB = startDelta
        lastNablaW = startDelta `outer` previousA
          -- activation feeding the output layer; for a single-layer network
          -- that is the raw input itself
          where previousA
                  | length rawFeedforward > 1 = snd $ rawFeedforward !! 1
                  | otherwise = fst spl
        lastLayer = Layer { weights = lastNablaW, biases = lastNablaB }
        -- reverse layers, analogy to the reversed (z, a) list
        layersReversed = reverse $ layers net
        -- walk backwards through the remaining layers, threading the delta;
        -- prepending each nabla restores forward layer order in finalNablas
        (finalNablas, _) = foldl calculate ([lastLayer], startDelta)
                                 [1..length layersReversed - 1]
        -- takes the index (into the reversed lists) and updates nablas
        calculate (nablas, oldDelta) idx =
          let -- extract raw and activated value for this layer
              (z, _) = rawFeedforward !! idx
              -- apply derivative of sigmoid element-wise
              z' = cmap sigmoid' z
              -- new delta: previous delta pushed through the transposed
              -- weights of the layer one step closer to the output
              w = weights $ layersReversed !! (idx - 1)
              delta = (tr w #> oldDelta) * z'
              -- nablaB is just the delta vector
              nablaB = delta
              -- activation in the previous layer (idx + 1 is in range: the
              -- (z, a) list has one more entry than there are layers)
              aPrevious = snd $ rawFeedforward !! (idx + 1)
              -- outer product of delta and the previous layer's activation
              nablaW = delta `outer` aPrevious
              -- put nablas into a new layer record
          in (Layer { weights = nablaW, biases = nablaB } : nablas, delta)
  207. sigmoid :: Floating a => ActivationFunction a
  208. sigmoid x = 1 / (1 + exp (-x))
  209. sigmoid' :: Floating a => ActivationFunctionDerivative a
  210. sigmoid' x = sigmoid x * (1 - sigmoid x)
-- | Return a uniformly random permutation of the list in O(n), using a
-- Fisher-Yates style pass over a mutable array with 1-based indices.
shuffle :: [a] -> IO [a]
shuffle xs = do
  ar <- newArr n xs
  forM [1..n] $ \i -> do
    -- Pick a partner index j >= i; emit the value at j and move the value
    -- at i into slot j, so it can still be selected in later iterations.
    j <- randomRIO (i,n)
    vi <- readArray ar i
    vj <- readArray ar j
    writeArray ar j vi
    return vj
  where
    n = length xs
    -- Mutable working copy of the input list, indexed 1..len.
    newArr :: Int -> [a] -> IO (IOArray Int a)
    newArr len lst = newListArray (1,len) lst
  224. saveNetwork :: (Element a, Binary a) => FilePath -> Network a -> IO ()
  225. saveNetwork fp net = do
  226. ex <- doesFileExist fp
  227. case ex of
  228. True -> saveNetwork (newFileName fp) net
  229. False -> encodeFile fp net
-- | Derive the next versioned file name, e.g. @net3.bin@ -> @net4.bin@.
-- The PCRE pattern splits the path into an optional prefix ending in a
-- letter, a trailing number before the extension, and the extension itself;
-- the number is incremented (missing number counts as 0).
-- NOTE(review): when the pattern does not match at all (e.g. no extension),
-- the fallback just appends "l" — presumably a cheap way to force a new,
-- distinct name; verify against 'saveNetwork' usage.
newFileName :: FilePath -> FilePath
newFileName fp = case fp =~ "(.+[a-z]){0,1}([0-9]*)(\\..*)" :: [[String]] of
    [[_, p, v, s]] -> p ++ show (version v + 1) ++ s
    _ -> fp ++ "l"
  where version :: String -> Int
        -- Safe parse: a non-numeric (or empty) version string becomes 0.
        version xs = fromMaybe 0 (readMaybe xs :: Maybe Int)
-- | Read a network back from a file previously written by 'saveNetwork'.
loadNetwork :: (Element a, Binary a) => FilePath -> IO (Network a)
loadNetwork = decodeFile