瀏覽代碼

force strict computation

master
erichhasl 8 年之前
父節點
當前提交
1935ad124f
共有 2 個檔案被更改,包括 31 行新增、9 行刪除
  1. +23
    -3
      MNIST.hs
  2. +8
    -6
      Network.hs

+ 23
- 3
MNIST.hs 查看文件

@@ -3,15 +3,18 @@
import qualified Data.ByteString.Lazy as BS

import Data.Int
import Data.Binary
import Data.List (findIndex, maximumBy)
import Data.List.Split (chunksOf)
import Data.Ord (comparing)
import Debug.Trace (trace)

import Codec.Compression.GZip
import qualified Data.ByteString.Lazy as B

import Control.DeepSeq

import Numeric.LinearAlgebra
import Codec.Compression.GZip
import System.Random
import System.Environment (getArgs)
import System.Console.CmdArgs.Implicit
@@ -92,6 +95,23 @@ testSamples = do
ims <- testIms
return $! [ img --> toLabel lbl | (lbl, img) <- ims]

-- | Wrapper around 'Sample' that carries a 'Binary' instance, so sample
-- sets can be serialized to disk. The input and expected-output vectors
-- are stored as plain element lists.
newtype StorableSample a = StorableSample { getSample :: Sample a }

instance (Element a, Binary a) => Binary (StorableSample a) where
  -- Serialize both halves of the sample as lists, input first.
  put (StorableSample (input, expected)) =
    put (toList input) >> put (toList expected)
  -- Read the two lists back in the same order and rebuild the sample.
  get = do
    input    <- get
    expected <- get
    return $ StorableSample (fromList input --> fromList expected)

-- | Read a sample set from disk: the file is gzip-compressed
-- 'Binary'-encoded data, as produced by 'saveSamples'.
-- NOTE(review): 'B.readFile' is lazy I/O, so decoding errors may surface
-- later than the call site — confirm callers force the result.
loadSamples :: FilePath -> IO (Samples Double)
loadSamples fp = do
  raw <- B.readFile fp
  return (map getSample (decode (decompress raw)))

-- | Write a sample set to disk, 'Binary'-encoded and gzip-compressed;
-- the inverse of 'loadSamples'.
saveSamples :: FilePath -> Samples Double -> IO ()
saveSamples fp samples =
  B.writeFile fp (compress (encode (map StorableSample samples)))

classify :: Network Double -> Sample Double -> IO ()
classify net spl = do
putStrLn (drawImage (fst spl)
@@ -115,8 +135,8 @@ main = do
net <- case filePath args of
"" -> newNetwork [784, hiddenNeurons args, 10]
fp -> loadNetwork fp :: IO (Network Double)
trSamples <- trainSamples
tstSamples <- testSamples
trSamples <- loadSamples "mnist-data/trainSamples.spls"
tstSamples <- loadSamples "mnist-data/testSamples.spls"
let bad = tstSamples `deepseq` test net tstSamples
putStrLn $ "Initial performance of network: recognized " ++ show bad ++ " of 10k"
let debug network epochs = "Left epochs: " ++ show epochs


+ 8
- 6
Network.hs 查看文件

@@ -40,6 +40,7 @@ module Network (
) where

import Data.List.Split (chunksOf)
import Data.List (foldl')
import Data.Binary
import Data.Maybe (fromMaybe)
import Text.Read (readMaybe)
@@ -146,7 +147,7 @@ output :: (Numeric a, Num (Vector a))
-> ActivationFunction a
-> Vector a
-> Vector a
output net act input = foldl f input (layers net)
output net act input = foldl' f input (layers net)
where f vec layer = cmap act ((weights layer #> vec) + biases layer)

rawOutputs :: (Numeric a, Num (Vector a))
@@ -206,7 +207,7 @@ trainSGD :: (Numeric Double, Floating Double)
-> Double
-> Network Double
trainSGD net costFunction lambda trainSamples miniBatchSize eta =
foldl updateMiniBatch net (chunksOf miniBatchSize trainSamples)
foldl' updateMiniBatch net (chunksOf miniBatchSize trainSamples)
where updateMiniBatch = update eta costFunction lambda (length trainSamples)

update :: LearningRate
@@ -220,10 +221,10 @@ update eta costFunction lambda n net spls = case newNablas of
Nothing -> net
Just x -> net { layers = layers' x }
where newNablas :: Maybe [Layer Double]
newNablas = foldl updateNablas Nothing spls
newNablas = foldl' updateNablas Nothing spls
updateNablas :: Maybe [Layer Double] -> Sample Double -> Maybe [Layer Double]
updateNablas mayNablas sample =
let nablasDelta = backprop net costFunction sample
let !nablasDelta = backprop net costFunction sample
f nabla nablaDelta =
nabla { weights = weights nabla + weights nablaDelta,
biases = biases nabla + biases nablaDelta }
@@ -252,8 +253,9 @@ backprop net costFunction spl = finalNablas
(headZ, headA) = head rawFeedforward
-- get starting delta, based on the activation of the last layer
startDelta = getDelta costFunction headZ headA (snd spl)
-- calculate weighs of last layer in advance
-- calculate nabla of biases
lastNablaB = startDelta
-- calculate nabla of weights of last layer in advance
lastNablaW = startDelta `outer` previousA
where previousA
| length rawFeedforward > 1 = snd $ rawFeedforward !! 1
@@ -262,7 +264,7 @@ backprop net costFunction spl = finalNablas
-- reverse layers, analogy to the reversed (z, a) list
layersReversed = reverse $ layers net
-- calculate nablas, beginning at the end of the network (startDelta)
(finalNablas, _) = foldl calculate ([lastLayer], startDelta)
(finalNablas, _) = foldl' calculate ([lastLayer], startDelta)
[1..length layersReversed - 1]
-- takes the index and updates nablas
calculate (nablas, oldDelta) idx =


Loading…
取消
儲存