# Nested for Expression Discovery

In [1]:
import numpy as np
import random
import math
import copy

In [2]:
atoms = [1, 2, 3, '+', '-']
children = [0, 0, 0, 2, 2]
MaxLength = 11

def legalMoves (state, leaves):
    l = []
    for a in range (len (atoms)):
        if len (state) + leaves + children [a] <= MaxLength:
            l.append (a)
    return l

def play (state, move, leaves):
    state.append (move)
    return [state, leaves - 1 + children [move]]

def terminal (state, leaves):
    return leaves == 0

def score (state, i):
    if children [state [i]] == 0:
        return [atoms [state [i]], i + 1]
    if children [state [i]] == 2:
        a = atoms [state [i]]
        [s1,i] = score (state, i + 1)
        [s2,i] = score (state, i)
        if a == '+':
            return [s1 + s2, i]
        if a == '-':
            return [s1 - s2, i]


In [3]:
def playout (state, leaves):
    while not terminal (state, leaves):
        moves = legalMoves (state, leaves)
        move = moves [int(random.random () * len (moves))]
        [state, leaves] = play (state, move, leaves)
    return state

def nested (state, leaves, n):
    bestSequence = []
    bestScore = -10e9
    while not terminal (state, leaves):
        moves = legalMoves (state, leaves)
        for m in moves:
            s1 = copy.deepcopy (state)
            [s1, leaves1] = play (s1, m, leaves)
            if (n == 1):
                s1 = playout (s1, leaves1)
            else:
                s1 = nested (s1, leaves1, n - 1)
            [score1, i] = score (s1, 0)
            if score1 > bestScore:
                bestScore = score1
                bestSequence = s1
        [state, leaves] = play (state, bestSequence [len (state)], leaves)
    return state

In [4]:
import sys

def printExpression (state):
    for i in state:
        sys.stdout.write (str (atoms [i]) + ' ')
    sys.stdout.write ('\n')
    
def test ():
    for i in range (10):
        s = playout ([], 1)
        printExpression (s)
        print (score (s, 0) [0])
    for i in range (10):
        s = nested ([], 1, 2)
        printExpression (s)
        print (score (s, 0) [0])

test ()

+ - 1 1 - 2 3 
-1
+ + + 2 3 3 - 1 1 
8
2 
2
+ 3 3 
6
2 
2
2 
2
1 
1
3 
3
1 
1
+ 2 - 3 2 
3
+ 1 + + + 3 3 + 3 3 3 
16
+ 3 + + + 3 3 + 3 3 3 
18
- 2 - - 1 3 + + 3 3 3 
13
+ 2 + + 3 3 + + 3 3 3 
17
+ + 3 + 3 + + 3 3 3 3 
18
+ + + 3 + 3 + 3 3 3 3 
18
- + + 3 + 3 + 3 3 3 1 
14
+ + 3 3 + 3 + + 3 3 3 
18
+ + 3 3 + + 3 + 3 3 3 
18
+ 3 + + + + 3 3 3 3 3 
18


# Discovery of Exploration Terms

In [2]:
!wget https://www.lamsade.dauphine.fr/~cazenave/generate.zip
!unzip generate.zip

--2025-03-11 16:34:46--  https://www.lamsade.dauphine.fr/~cazenave/generate.zip
Résolution de www.lamsade.dauphine.fr (www.lamsade.dauphine.fr)… 193.48.71.250
Connexion à www.lamsade.dauphine.fr (www.lamsade.dauphine.fr)|193.48.71.250|:443… connecté.
requête HTTP transmise, en attente de la réponse… 200 OK
Taille : 183517862 (175M) [application/zip]
Enregistre : ‘generate.zip’


2025-03-11 16:38:21 (836 KB/s) - ‘generate.zip’ enregistré [183517862/183517862]

Archive:  generate.zip
   creating: generate/
  inflating: generate/946.data       
  inflating: generate/487.data       
  inflating: generate/636.data       
  inflating: generate/286.data       
  inflating: generate/651.data       
  inflating: generate/1026.data      
  inflating: generate/976.data       
  inflating: generate/762.data       
  inflating: generate/72.data        
  inflating: generate/167.data       
  inflating: generate/1442.data      
  inflating: generate/650.data       
  inflating: generate/918.data    

  inflating: generate/1544.data      
  inflating: generate/1507.data      
  inflating: generate/242.data       
  inflating: generate/596.data       
  inflating: generate/1336.data      
  inflating: generate/234.data       
  inflating: generate/792.data       
  inflating: generate/1229.data      
  inflating: generate/1098.data      
  inflating: generate/345.data       
  inflating: generate/575.data       
  inflating: generate/1564.data      
  inflating: generate/1213.data      
  inflating: generate/457.data       
  inflating: generate/1112.data      
  inflating: generate/618.data       
  inflating: generate/680.data       
  inflating: generate/516.data       
  inflating: generate/553.data       
  inflating: generate/1661.data      
  inflating: generate/574.data       
  inflating: generate/831.data       
  inflating: generate/382.data       
  inflating: generate/1390.data      
  inflating: generate/454.data       
  inflating: generate/305.data       
  inflating:

  inflating: generate/1582.data      
  inflating: generate/363.data       
  inflating: generate/748.data       
  inflating: generate/1420.data      
  inflating: generate/408.data       
  inflating: generate/720.data       
  inflating: generate/284.data       
  inflating: generate/729.data       
  inflating: generate/727.data       
  inflating: generate/1283.data      
  inflating: generate/1539.data      
  inflating: generate/1503.data      
  inflating: generate/1492.data      
  inflating: generate/1.data         
  inflating: generate/722.data       
  inflating: generate/615.data       
  inflating: generate/972.data       
  inflating: generate/700.data       
  inflating: generate/520.data       
  inflating: generate/1664.data      
  inflating: generate/204.data       
  inflating: generate/62.data        
  inflating: generate/836.data       
  inflating: generate/70.data        
  inflating: generate/593.data       
  inflating: generate/662.data       
  inflating:

  inflating: generate/1253.data      
  inflating: generate/112.data       
  inflating: generate/975.data       
  inflating: generate/114.data       
  inflating: generate/1552.data      
  inflating: generate/1510.data      
  inflating: generate/398.data       
  inflating: generate/1208.data      
  inflating: generate/1519.data      
  inflating: generate/102.data       
  inflating: generate/342.data       
  inflating: generate/1065.data      
  inflating: generate/115.data       
  inflating: generate/230.data       
  inflating: generate/333.data       
  inflating: generate/1139.data      
  inflating: generate/1477.data      
  inflating: generate/1106.data      
  inflating: generate/515.data       
  inflating: generate/9.data         
  inflating: generate/1310.data      
  inflating: generate/383.data       
  inflating: generate/1484.data      
  inflating: generate/252.data       
  inflating: generate/950.data       
  inflating: generate/1363.data      
  inflating:

  inflating: generate/1230.data      
  inflating: generate/774.data       
  inflating: generate/20.data        
  inflating: generate/781.data       
  inflating: generate/1668.data      
  inflating: generate/953.data       
  inflating: generate/585.data       
  inflating: generate/1669.data      
  inflating: generate/16.data        
  inflating: generate/642.data       
  inflating: generate/370.data       
  inflating: generate/36.data        
  inflating: generate/1275.data      
  inflating: generate/14.data        
  inflating: generate/1376.data      
  inflating: generate/1574.data      
  inflating: generate/519.data       
  inflating: generate/536.data       
  inflating: generate/1670.data      
  inflating: generate/648.data       
  inflating: generate/1383.data      
  inflating: generate/927.data       
  inflating: generate/349.data       
  inflating: generate/600.data       
  inflating: generate/109.data       
  inflating: generate/1136.data      
  inflating:

  inflating: generate/1013.data      
  inflating: generate/626.data       
  inflating: generate/1259.data      
  inflating: generate/497.data       
  inflating: generate/784.data       
  inflating: generate/49.data        
  inflating: generate/17.data        
  inflating: generate/843.data       
  inflating: generate/261.data       
  inflating: generate/1364.data      
  inflating: generate/592.data       
  inflating: generate/663.data       
  inflating: generate/1432.data      
  inflating: generate/913.data       
 extracting: generate/1653.data      
  inflating: generate/1221.data      
  inflating: generate/994.data       
  inflating: generate/1649.data      
  inflating: generate/1031.data      
  inflating: generate/1462.data      
  inflating: generate/269.data       
  inflating: generate/1339.data      
  inflating: generate/1592.data      
  inflating: generate/666.data       
  inflating: generate/605.data       
  inflating: generate/1107.data      
  inflating:

  inflating: generate/389.data       
  inflating: generate/1515.data      
  inflating: generate/1157.data      
  inflating: generate/6.data         
  inflating: generate/1049.data      
  inflating: generate/38.data        
  inflating: generate/1344.data      
  inflating: generate/1453.data      
  inflating: generate/1475.data      
  inflating: generate/229.data       
  inflating: generate/1655.data      
  inflating: generate/1569.data      
  inflating: generate/1352.data      
 extracting: generate/1692.data      
  inflating: generate/586.data       
  inflating: generate/1118.data      
  inflating: generate/56.data        
  inflating: generate/1640.data      
  inflating: generate/1247.data      
  inflating: generate/817.data       
  inflating: generate/1319.data      
  inflating: generate/1188.data      
  inflating: generate/1370.data      
  inflating: generate/1570.data      
  inflating: generate/1020.data      
  inflating: generate/1458.data      
  inflating:

  inflating: generate/1455.data      
  inflating: generate/565.data       
  inflating: generate/761.data       
  inflating: generate/1062.data      
  inflating: generate/8.data         
  inflating: generate/962.data       
  inflating: generate/28.data        
  inflating: generate/254.data       
  inflating: generate/496.data       
  inflating: generate/1147.data      
  inflating: generate/119.data       
  inflating: generate/1445.data      
  inflating: generate/1531.data      
  inflating: generate/1142.data      
  inflating: generate/776.data       
  inflating: generate/477.data       
  inflating: generate/1527.data      
  inflating: generate/549.data       
  inflating: generate/330.data       
  inflating: generate/424.data       
  inflating: generate/1326.data      
  inflating: generate/15.data        
  inflating: generate/1022.data      
  inflating: generate/1226.data      
  inflating: generate/862.data       
  inflating: generate/1354.data      
  inflating:

In [5]:
import os

nbData = 1000

moveKatago = []
color = []
evalKatago = []

moveData = []
priorMove = []
scoreData = []

for i in range (nbData):
    name = "generate/" + str (i) + ".data"
    if os.path.exists(name):
        fichier = open(name, "r")
        #print ('read' + name)                                                                                                                                                                              
        state = []
        label = [x for x in next(fichier).split()] # read first line                                                                                                                                        
        moveKatago.append (int (label [0]))
        color.append (label [1])
        evalKatago.append (float (label [2]))
        values = [x for x in next(fichier).split()]
        moves = []
        priors = []
        scores = []
        for i in range (10):
            values = [x for x in next(fichier).split()]
            moves.append (int (values [0]))
            priors.append (float (values [1]))
            nb = int (values [2])
            l = []
            for j in range (3, len (values)):
                l.append (float (values [j]))
            scores.append (l)
        fichier.close ()
        moveData.append (moves)
        priorMove.append (priors)
        scoreData.append (scores)
    else:
        print ('pb reading', name)

print (len(moveData))

1000


In [8]:
import copy
import numpy as np

def bestHalf (moves, sumScores, nbPlayouts):
    half = []
    notused = list(np.full(361,True))
    for x in range (int(np.ceil(len (moves) / 2))):
        best = -1.0
        bestMove = moves [0]
        for m in moves:
            if notused [m]:
                mu = sumScores [m] / nbPlayouts [m]
                if mu > best:
                    best = mu
                    bestMove = m
        notused [bestMove] = False
        half.append (bestMove)
    return half

count = 0
for i in range (len (moveData)):
    moves = copy.deepcopy (moveData [i])
    total = 128
    nbPlayouts = [0 for x in range (361)]
    sumScores = [0.0 for x in range (361)]
    while (len (moves) > 1):
        for m in moves:
            bestj = 0
            for j in range (len (moveData [i])):
                if m == moveData [i] [j]:
                    bestj = j
            for j in range (max (1, int (nb // (len (moves) * np.log2 (total))))):
                value = scoreData [i] [bestj] [nbPlayouts [m]]
                if color [i] == 'b':
                    value = 1 - value
                sumScores [m] += value
                nbPlayouts [m] += 1
        moves = bestHalf (moves, sumScores, nbPlayouts)                                                                                                                                            
    if moves [0] == moveKatago [i]:
        count = count + 1
print ('score SH =', count)


score SH = 162


In [9]:
import copy
import numpy as np

atoms = ['+', '*', 'sc', 'pr', '2']
children = [2, 2, 0, 0, 0]
MaxLength = 7

def legalMoves (state, leaves):
    l = []
    for a in range (len (atoms)):
        if len (state) + leaves + children [a] <= MaxLength:
            l.append (a)
    return l

def play (state, move, leaves):
    state.append (move)
    return [state, leaves - 1 + children [move]]

def terminal (state, leaves):
    return leaves == 0

def playout (state, leaves):
    while not terminal (state, leaves):
        moves = legalMoves (state, leaves)
        move = moves [int(random.random () * len (moves))]
        [state, leaves] = play (state, move, leaves)
    return state

def score (state, i, prior, sumScores):
    if children [state [i]] == 0:
        if state [i] == '2':
            return [2, i + 1]
        if state [i] == 'sc':
            return [sumScores, i + 1]
        if state [i] == 'pr':
            return [prior, i + 1]
    if children [state [i]] == 2:
        a = atoms [state [i]]
        [s1,i] = score (state, i + 1, prior, sumScores)
        [s2,i] = score (state, i, prior, sumScores)
        if a == '+':
            return [s1 + s2, i]
        if a == '*':
            return [s1 * s2, i]
        
def bestHalf (moves, sumScores, nbPlayouts, expression, prior):
    half = []
    notused = list(np.full(361,True))
    for x in range (int(np.ceil(len (moves) / 2))):
        best = -1.0
        bestMove = moves [0]
        for m in moves:
            if notused [m]:
                mu = score (expression, 0, prior [m], sumScores [m])
                if mu > best:
                    best = mu
                    bestMove = m
        notused [bestMove] = False
        half.append (bestMove)
    return half

def evaluate (expression):
    count = 0
    for i in range (len (moveData)):
        moves = copy.deepcopy (moveData [i])
        total = 128
        nbPlayouts = [0 for x in range (361)]
        sumScores = [0.0 for x in range (361)]
        priors = [0.0 for x in range (361)]
        while (len (moves) > 1):
            for m in moves:
                bestj = 0
                for j in range (len (moveData [i])):
                    if m == moveData [i] [j]:
                        bestj = j
                priors [m] = priorMove [i] [bestj]
                for j in range (max (1, int (nb // (len (moves) * np.log2 (total))))):
                    value = scoreData [i] [bestj] [nbPlayouts [m]]
                    if color [i] == 'b':
                        value = 1 - value
                    sumScores [m] += value
                    nbPlayouts [m] += 1
            moves = bestHalf (moves, sumScores, nbPlayouts, expression, priors)                                                                                                                                            
        if moves [0] == moveKatago [i]:
            count = count + 1
    return count

print ('score SH =', evaluate (['sc']))

TypeError: list indices must be integers or slices, not str