浏览代码

improve performance of MCTS

montecarlo
父节点
当前提交
de36cf244c
签署人:: christian <christian@flavigny.de> GPG 密钥 ID: D953D69721B948B3
共有 7 个文件被更改,包括 126 次插入50 次删除
  1. +24
    -0
      src/Skat/AI/Base.hs
  2. +62
    -12
      src/Skat/AI/Games/Skat/Guess.hs
  3. +9
    -27
      src/Skat/AI/MonteCarlo.hs
  4. +23
    -3
      src/Skat/AI/Skat.hs
  5. +3
    -3
      src/Skat/AI/TicTacToe.hs
  6. +2
    -1
      src/Skat/Card.hs
  7. +3
    -4
      src/Skat/Pile.hs

+ 24
- 0
src/Skat/AI/Base.hs 查看文件

@@ -7,6 +7,8 @@

module Skat.AI.Base where

import System.Random (Random)
import qualified System.Random as Rand
import Control.Monad.State
import Control.Exception (assert)
import Control.Monad.Fail
@@ -46,3 +48,25 @@ class (MonadIO m, Show t, Show v, Show p, MonadGame t l v p m) => PlayableGame t

class Choose t m | m -> t where
choose :: m t

-- | Monads that can act as a source of randomness.
class MonadRandom m where
  -- | Produce a random value of any type that has a 'Random' instance.
  random :: Random a => m a
  -- | Select one element of the given list. Both instances in this module
  -- call 'error' when the list is empty.
  chooser :: [a] -> m a

-- | Randomness taken from the global generator via 'Rand.randomIO' and
-- 'Rand.randomRIO'.
instance MonadRandom IO where
  random = Rand.randomIO
  chooser candidates = case candidates of
    [] -> error "chooser: empty list"
    _  -> do
      idx <- Rand.randomRIO (0, length candidates - 1)
      pure (candidates !! idx)

-- | Pure randomness: the generator is read from and written back to the
-- 'State' on every draw.
instance MonadRandom (State Rand.StdGen) where
  random =
    get >>= \gen ->
      let (value, gen') = Rand.random gen
      in put gen' >> return value
  chooser candidates = case candidates of
    [] -> error "chooser: empty list"
    _  ->
      get >>= \gen ->
        -- Index is drawn from the inclusive range of valid positions.
        let (idx, gen') = Rand.randomR (0, length candidates - 1) gen
        in put gen' >> return (candidates !! idx)

+ 62
- 12
src/Skat/AI/Games/Skat/Guess.hs 查看文件

@@ -8,6 +8,7 @@ import GHC.Generics (Generic, Generic1)
import Data.Ord
import Data.Monoid ((<>))
import Data.List
import Data.Set (Set)
import qualified Data.Set as S
import Control.Monad.State
import Control.Monad.Reader
@@ -18,6 +19,7 @@ import Data.Bits
import Debug.Trace

import Skat
import Skat.AI.Base
import Skat.Utils
import Skat.Card
import Skat.Pile
@@ -30,14 +32,14 @@ data Option = H Hand
| Skt
deriving (Show, Eq, Ord, Generic, NFData)

type Guess = Map Card [Option]
type Guess = Map Card (Set Option)

newGuess :: Guess
newGuess = newGuessWith allCards

newGuessWith :: [Card] -> Guess
newGuessWith cards = M.fromList l
where l = map (\c -> (c, [H Hand1, H Hand2, H Hand3, Skt])) cards
where l = map (\c -> (c, S.fromList [H Hand1, H Hand2, H Hand3, Skt])) cards

hasBeenPlayed :: Card -> Guess -> Guess
hasBeenPlayed card = M.delete card
@@ -45,33 +47,33 @@ hasBeenPlayed card = M.delete card
has :: Hand -> [Card] -> Guess -> Guess
has hand cs = M.mapWithKey f
where f card hands
| card `elem` cs = [H hand]
| card `elem` cs = S.singleton (H hand)
| otherwise = hands

hasOnly :: Hand -> [Card] -> Guess -> Guess
hasOnly hand cs = M.mapWithKey f
where f card hands
| card `elem` cs = [H hand]
| otherwise = delete (H hand) hands
| card `elem` cs = S.singleton (H hand)
| otherwise = S.delete (H hand) hands

hasOnly_ :: Option -> [Card] -> Guess -> Guess
hasOnly_ option cs = M.mapWithKey f
where f card hands
| card `elem` cs = [option]
| otherwise = h option hands
h a b = delete a b
| card `elem` cs = S.singleton option
| otherwise = S.delete option hands

hasNoLonger :: Trump -> Hand -> TurnColour -> Guess -> Guess
hasNoLonger trump hand effCol = M.mapWithKey f
where f card hands
| effectiveColour trump card == effCol && (H hand) `elem` hands = filter (/=H hand) hands
| effectiveColour trump card == effCol && (H hand) `S.member` hands =
S.filter (/=H hand) hands
| otherwise = hands

isSkat :: [Card] -> Guess -> Guess
isSkat cs = M.mapWithKey f
where f card hands
| card `elem` cs = [Skt]
| otherwise = if length cs == 2 then delete Skt hands else hands
| card `elem` cs = S.singleton Skt
| otherwise = if length cs == 2 then S.delete Skt hands else hands

choosen1 :: Int -> [a] -> [[a]]
choosen1 !n !cs = map f (filter ((==n) . popCount) [0..(m-1)])
@@ -112,7 +114,7 @@ distributions2 !guess1 !(n1, n2, n3, nskt) = do

carddist :: Option -> Int -> Guess -> [[Card]]
carddist option n guess = choosen n options
where options = M.keys $ M.filter (option `elem`) guess
where options = M.keys $ M.filter (option `S.member`) guess

carddistS :: Option -> Int -> StateT Guess [] [Card]
carddistS option n = do
@@ -130,6 +132,53 @@ distributions3 guess (n1, n2, n3, n4) = (flip evalStateT) guess $ do
return (hand1, hand2, hand3, skt)
where cardsPerHand = (length guess-2-n1-n2-n3) `div` 3

-- | Randomly pick one of the possible owners of a card, respecting capacity.
--
-- The state carries four counters, matched in order against @H Hand1@,
-- @H Hand2@, @H Hand3@ and @Skt@. Only options whose counter is still
-- positive are eligible; the counter of the chosen option is decremented.
-- When no option is eligible, 'chooser' is handed an empty list.
randomChoice :: (MonadRandom m, Monad m) => Set Option -> StateT (Int, Int, Int, Int) m Option
randomChoice options = do
  (n1, n2, n3, n4) <- get
  let hasRoom opt = case opt of
        H Hand1 -> n1 > 0
        H Hand2 -> n2 > 0
        H Hand3 -> n3 > 0
        Skt     -> n4 > 0
      eligible = filter hasRoom (S.toList options)
  picked <- lift (chooser eligible)
  -- Consume one slot of whichever pile was picked.
  let (m1, m2, m3, m4) = case picked of
        H Hand1 -> (n1 - 1, n2, n3, n4)
        H Hand2 -> (n1, n2 - 1, n3, n4)
        H Hand3 -> (n1, n2, n3 - 1, n4)
        Skt     -> (n1, n2, n3, n4 - 1)
  put (m1, m2, m3, m4)
  return picked

-- | Collapse a guess into one randomly chosen concrete assignment.
--
-- The cards are visited in key order; each card's option set is replaced by
-- a singleton drawn with 'randomChoice'. The initial capacities are
-- @cardsPerHand + n@ for each of the three hands and @2 + n4@ for the skat.
randomGuess :: (MonadRandom m, Monad m) => Guess -> (Int, Int, Int, Int) -> m Guess
randomGuess guess (n1, n2, n3, n4) =
  evalStateT (foldM assign guess (M.keys guess)) capacities
  where
    cardsPerHand = (length guess - 2 - n1 - n2 - n3) `div` 3
    capacities =
      ( cardsPerHand + n1
      , cardsPerHand + n2
      , cardsPerHand + n3
      , 2 + n4
      )
    -- Fix the owner of a single card inside the accumulated guess.
    assign acc card = do
      owner <- randomChoice (M.findWithDefault (error "findWithDefault") card acc)
      pure (M.insert card (S.singleton owner) acc)
-- | Draw one random card distribution that is consistent with a guess.
--
-- The guess is first collapsed with 'randomGuess', so every card has exactly
-- one possible owner, and is then handed to 'distributions' with the same
-- offsets. The result is expected to be a single distribution; any other
-- outcome raises an error naming this function once the value is demanded.
--
-- NOTE(review): assumes 'distributions' yields exactly one result for a fully
-- determined guess, as the previous @let [d] = ...@ pattern did -- confirm.
randomDistr :: (MonadRandom m, Monad m) => Guess -> (Int, Int, Int, Int) -> m Distribution
randomDistr guess ns = do
  sampled <- randomGuess guess ns
  -- Kept lazy on purpose: the check only fires when the result is forced.
  pure $ case distributions sampled ns of
    [d] -> d
    ds  -> error $ "randomDistr: expected exactly one distribution, got "
                     ++ show (length ds)
{-
distributions1 :: Guess -> (Int, Int, Int, Int) -> [Distribution]
distributions1 guess nos =
helper (sortBy compareGuess $ M.toList guess) nos
@@ -155,6 +204,7 @@ distributions1 guess nos =
isOk Skt = n4 < 2
in filterMap isOk (f card) hands
cardsPerHand = (length guess - 2) `div` 3
-}

distributions = distributions3



+ 9
- 27
src/Skat/AI/MonteCarlo.hs 查看文件

@@ -32,7 +32,7 @@ import qualified System.Random as Rand
import Text.Printf
import Data.List.Split

import Skat.AI.Base
import Skat.AI.Base hiding (simulate)
import qualified Skat as S
import qualified Skat.Card as S
import qualified Skat.Operations as S
@@ -98,26 +98,6 @@ valuation Pending{} = (0,0)

deriving instance (Show s, Show t) => Show (Tree t s)

class MonadRandom m where
random :: Random a => m a
chooser :: [a] -> m a

instance MonadRandom IO where
random = Rand.randomIO
chooser os = (os!!) <$> Rand.randomRIO (0, length os -1)

instance MonadRandom (State Rand.StdGen) where
random = do
gen <- get
let (a, gen') = Rand.random gen
put gen'
return a
chooser os = do
gen <- get
let (a, gen') = Rand.randomR (0, length os -1) gen
put gen'
return (os !! a)

{-
valuetonum :: (Fractional a, Value v) => v -> a
valuetonum v
@@ -143,6 +123,8 @@ class (Player p, Value d) => HasGameState t p d s | s -> d, s -> p, s -> t where
execute :: t -> s -> s
monteevaluate :: s -> d
current :: s -> p
simulate :: (Monad m, MonadRandom m) => s -> m d
simulate = montesimulate

montecarlo :: (Show s, Show t, Eq p, Show d, Monad m, HasGameState t p d s, MonadRandom m)
=> Tree t s
@@ -151,12 +133,12 @@ montecarlo (Pending state turn) = do
let currentTeam = current state
state' = execute turn state
-- objectively get a final score of random playout (independent of perspective)
values <- replicateM 1 (montesimulate state')
let tr = if maxing (current state') then id else invert
values <- replicateM 100 (simulate state')
let tr = if maxing (current state) then id else invert
vs = fmap (tonum . tr) values
n = sum vs / 1
n = sum vs
--let v = if maxing (current state') then value else invert value
let val = (n, 1)
let val = (n, 100)
pure $ Leaf state' False val
montecarlo (Leaf state terminal d)
| terminal || length ms == 0 = pure $ Leaf state True d
@@ -182,12 +164,12 @@ montecarlo n@(Node state _ d children)
else
if current state == current (treestate updated)
then diff
else 1 - diff
else fromIntegral (simruns updated) - diff
newWins = diff2 + fst d
--return $ trace ("updating node " ++ show diff2 ++ "\n" ++ show updated ++ "\n" ++ show bestChild) (Node state False (newWins, newSimRuns) cs)
return $ Node state False (newWins, newSimRuns) cs

montesimulate :: (Monad m, MonadRandom m, HasGameState t p d s, Show d)
montesimulate :: (Monad m, MonadRandom m, HasGameState t p d s)
=> s
-> m d
montesimulate state = case moves state of


+ 23
- 3
src/Skat/AI/Skat.hs 查看文件

@@ -165,6 +165,22 @@ instance HasGameState Turn Bool Float SkatState where
card <- ev S.allowedCards newEnv
pure $ Turn newEnv (S.toCard card)
where env = skatEnv s
-- Random playout for a Skat state: sample a card distribution that fits the
-- current guess, play one randomly chosen allowed card, and recurse.
simulate s
-- With at most two entries left in the guess the state is evaluated directly.
| Map.size (guess s) <= 2 = pure $ monteevaluate s
| otherwise = do
let currentPiles = ev (gets S.piles) env
table = S.tableCards currentPiles
-- n1..n3: how many of the table cards belong to each hand's pile.
n1 = length $ filter ((S.P S.Hand1==) . S.getPile) table
n2 = length $ filter ((S.P S.Hand2==) . S.getPile) table
n3 = length $ filter ((S.P S.Hand3==) . S.getPile) table
-- The counts are passed negated, which lowers the per-hand capacity used
-- when sampling; the skat offset stays 0.
ns = (-n1, -n2, -n3, 0)
d <- randomDistr (guess s) ns
-- Install the sampled distribution into a copy of the environment.
let newEnv = env { S.piles = updatePiles d (S.piles env) }
cards = ev S.allowedCards newEnv
-- Pick one of the cards allowed in the sampled environment.
card <- chooser cards
let newState = execute (Turn newEnv (S.toCard card)) s
-- Continue the playout through the class method of the resulting state.
Skat.AI.MonteCarlo.simulate newState
where env = skatEnv s

ev :: StateT S.SkatEnv (Writer [S.Trick]) a -> S.SkatEnv -> a
ev action = fst . runWriter . evalStateT action
@@ -186,9 +202,10 @@ playCLI n = do
liftIO $ putStrLn "iterating"
s <- get
let tree = Leaf s False (0, 0)
t = runmonte n (foldM (\tree _ -> montecarlo tree) tree [1..20])
t = runmonte n (foldM (\tree _ -> montecarlo tree) tree [1..100])
newstate = bestmove t
--liftIO $ putStrLn $ visualise t
liftIO $ print newstate
put newstate
lift (put $ skatEnv newstate)
else do
@@ -204,7 +221,7 @@ playCLI n = do
showBoard
liftIO $ getLine
-}
playCLI n
--playCLI n
where
--readTurn :: (MonadFail m, Read t, PlayableGame t l v p m) => m t
readTurn :: S.Skat (S.CardS S.Owner)
@@ -228,7 +245,7 @@ initSkatEnv n =
let gen = Rand.mkStdGen n
--cards = S.shuffle gen S.allCards
--piles = S.distribute cards
piles = S.cardDistr6
piles = S.cardDistr3
players = P.Players
(P.PL $ S.Stupid S.Single S.Hand1)
(P.PL $ S.Stupid S.Team S.Hand2)
@@ -257,3 +274,6 @@ initSkatState =
playSkat :: Int -> IO ()
playSkat n = let env = skatEnv initSkatState
in void $ S.evalSkat ( (flip runStateT) initSkatState (playCLI n) ) env

-- | Search tree made of a single 'Leaf' holding 'initSkatState', with the
-- terminal flag set to 'False' and the value pair set to @(0,0)@.
skattree :: Tree Turn SkatState
skattree = Leaf initSkatState False (0,0)

+ 3
- 3
src/Skat/AI/TicTacToe.hs 查看文件

@@ -210,12 +210,12 @@ playCLI n = do
if gameOver
then announceWinner
else do
current <- currentPlayer
--let current = False
--current <- currentPlayer
let current = False
if not current then do
s <- get
let tree = Leaf s False (0, 0)
t = bestmove $ runmonte n (foldM (\tree _ -> montecarlo tree) tree [1..5000])
t = bestmove $ runmonte n (foldM (\tree _ -> montecarlo tree) tree [1..1000])
put t
else do
showBoard


+ 2
- 1
src/Skat/Card.hs 查看文件

@@ -4,6 +4,7 @@
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE DeriveAnyClass #-}
{-# LANGUAGE BangPatterns #-}

module Skat.Card where

@@ -67,7 +68,7 @@ data TurnColour = TurnColour Colour
| Trump
deriving (Show, Eq)

data Card = Card Type Colour
data Card = Card !Type !Colour
deriving (Eq, Show, Ord, Read, Bounded, Generic)

getType :: Card -> Type


+ 3
- 4
src/Skat/Pile.hs 查看文件

@@ -292,14 +292,13 @@ cardDistr5 = makePiles hand1 hand2 hand3 tbl skt

cardDistr6 :: Piles
cardDistr6 = emptyPiles hand1 hand2 hand3 skt
where hand3 = [Card Ace Spades, Card Jack Diamonds, Card Jack Clubs, Card King Spades,
where hand1 = [Card Jack Diamonds, Card Jack Clubs, Card King Spades,
Card Nine Spades, Card Ace Diamonds, Card Queen Diamonds
]
hand1 = [Card Jack Spades, Card Jack Hearts, Card Ten Spades, Card Ace Hearts,
hand3 = [Card Jack Spades, Card Ten Spades, Card Ace Hearts,
Card Ten Hearts, Card Nine Hearts, Card Seven Clubs
]
hand2 = [Card Eight Spades, Card Queen Spades, Card Seven Spades, Card Seven Diamonds,
hand2 = [Card Queen Spades, Card Seven Spades, Card Seven Diamonds,
Card Seven Hearts, Card Eight Hearts, Card Queen Hearts
]
skt = [Card Nine Clubs, Card Queen Clubs]


正在加载...
取消
保存