Did a lot of stuff... Working towards the genetic algorithm
authorBrendan Hansen <brendan.f.hansen@gmail.com>
Tue, 12 Mar 2019 03:19:50 +0000 (22:19 -0500)
committerBrendan Hansen <brendan.f.hansen@gmail.com>
Tue, 12 Mar 2019 03:19:50 +0000 (22:19 -0500)
conf.lua
docs/NEAT [new file with mode: 0644]
main.lua
src/genetics.lua [new file with mode: 0644]
src/input.lua
src/neuralnet.lua [new file with mode: 0644]
src/utils.lua
src/world.lua
test_nn.lua [new file with mode: 0644]

index 20228a82a1bac18fd483d21e02ee8a00f79f2a84..f2ca2fbcdd2722746ea3a59b926758a922fcdf7b 100644 (file)
--- a/conf.lua
+++ b/conf.lua
@@ -41,28 +41,18 @@ local KEYMAP = {
        FIRE_RIGHT = "right";
 }
 
-local BACK_COLOR = { 0.8, 0.8, 0.8 }
-local PLAYER_COLOR = { 0.3, 0.3, 0.7 }
-local ENEMY_COLOR = { 1, 0, 0 }
-local BULLET_COLOR = { 0.6, 0.6, 1.0 }
-
-local PLAYER_VISION_SEGMENTS = 16
-local PLAYER_VISION_DISTANCE = 20
-
-local ENEMY_SIZE = 20
-
 return {
        WINDOW_WIDTH = WINDOW_WIDTH;
        WINDOW_HEIGHT = WINDOW_HEIGHT;
        KEYS = KEYMAP;
 
-       BACK_COLOR = BACK_COLOR;
-       PLAYER_COLOR = PLAYER_COLOR;
-       ENEMY_COLOR = ENEMY_COLOR;
-       BULLET_COLOR = BULLET_COLOR;
+       BACK_COLOR = { 0.8, 0.8, 0.8 };
+       PLAYER_COLOR = { 0.3, 0.3, 0.7 };
+       ENEMY_COLOR = { 1.0, 0.0, 0.0 };
+       BULLET_COLOR = { 0.6, 0.6, 1.0 };
 
-       PLAYER_VISION_SEGMENTS = PLAYER_VISION_SEGMENTS;
-       PLAYER_VISION_DISTANCE = PLAYER_VISION_DISTANCE;
+       PLAYER_VISION_SEGMENTS = 16;
+       PLAYER_VISION_DISTANCE = 20;
 
-       ENEMY_SIZE = ENEMY_SIZE;
+       ENEMY_SIZE = 20;
 }
diff --git a/docs/NEAT b/docs/NEAT
new file mode 100644 (file)
index 0000000..71f985a
--- /dev/null
+++ b/docs/NEAT
@@ -0,0 +1,45 @@
+------------ Document to help me understand how NEAT works --------------
+
+The words "neuron" and "node" are used interchangeably here.
+
+A genome describes a phenotype (in our case a neural network)
+A gene in the genome corresponds to a connection (a weight) in the neural network
+The genome starts as simply the input nodes and the output nodes
+       - New nodes are added through mutation
+
+Mutation can be:
+       - Changing the genes (changing the weights)
+       - Adding a gene (adding a connection)
+       - Splitting a gene into two, creating a new node in the process
+       
+
+
+The basic process is:
+	1. Create an empty list of species (list of genomes)
+	2. Populate the list with the desired number of genomes
+	3. Mutate each member of the species a little bit
+	4. Run the trials
+	5. Remove the lowest performing members, say the bottom 50%
+	6. Use breeding to combine the remaining members of the species into more members
+		(each step of this is called a generation)
+	7. Go to step 3
+
+So I guess I'm not using speciation....
+       That's okay for now
+
+
+
+
+
+
+
+In case I need it
+
+The compatibility formula is
+       delta = c1 * E / N + c2 * D / N + c3 * W
+
+	c1, c2, c3 are coefficients used to adjust the importance of the factors
+	N is the number of genes in the larger genome
+	E is the number of excess genes
+	D is the number of disjoint genes
+	W is the average weight difference of the matching genes
+
index fbb26b0b88d762e704d6459998583b7787261095..4325c73beae095a90c551c84ee89719f483b80a2 100644 (file)
--- a/main.lua
+++ b/main.lua
@@ -9,9 +9,9 @@ local Enemy = world_mod.Enemy
 local world, player
 local input
 function love.load()
-       world, player = World:new()
-       for i = 1, 100 do
-               local enemy = Enemy:new(math.random(800), math.random(600))
+       world, player = World.new()
+       for _ = 1, 100 do
+               local enemy = Enemy.new(math.random(800), math.random(600))
                world:add_entity(enemy)
        end
 
@@ -21,7 +21,7 @@ function love.load()
 end
 
 function love.keypressed(key)
-       input:keydown(key)      
+       input:keydown(key)
 end
 
 function love.keyreleased(key)
diff --git a/src/genetics.lua b/src/genetics.lua
new file mode 100644 (file)
index 0000000..8aa4871
--- /dev/null
@@ -0,0 +1,429 @@
+local NN = require "src.neuralnet"
+local NeuralNetwork = NN.NeuralNetwork
+
+-- Need a global-ish innovation number, since that depends on the whole training, not just a single genome
+local Current_Innovation = 1
+
+local function Get_Next_Innovation()
+       local tmp = Current_Innovation
+       Current_Innovation = Current_Innovation + 1
+       return tmp
+end
+
+-- N.E.A.T. genetic algorithm
+
+local Gene = {}
+local Gene_mt = { __index = Gene }
+
+function Gene.new()
+       local o = {
+               to = 0;
+               from = 0;
+               weight = 0;
+               enabled = true;
+               innovation = 0;
+       }
+
+       setmetatable(o, Gene_mt)
+       return o
+end
+
+function Gene:copy()
+       local new = Gene.new()
+
+       new.to = self.to
+       new.from = self.from
+       new.weight = self.weight
+       new.enabled = self.enabled
+       new.innovation = self.innovation
+
+       return new
+end
+
+
+-- Genome class --
+
+local Genome = {}
+local Genome_mt = { __index = Genome }
+
+function Genome.new(inputs, outputs)
+	-- Constructs an empty genome describing a network with `inputs` input
+	-- neurons (plus one implicit bias neuron) and `outputs` output neurons.
+	-- No genes (connections) exist yet; those are added by mutation.
+	local o = {
+		num_inputs = inputs + 1; -- We need one bias neuron that will always be 1
+		num_outputs = outputs;
+		genes = {};
+		fitness = 0;
+		network = {}; -- Neural Network
+		-- Highest numbered neuron in the genome.
+		-- BUG FIX: was `inputs + outputs`, but ids run 1..inputs+1 for the
+		-- inputs (incl. bias) and inputs+2..inputs+1+outputs for the outputs,
+		-- so the first neuron created by mutate_neuron collided with an output
+		high_neuron = inputs + 1 + outputs;
+
+		mutations = { -- The different chances of mutating a particular part of the genome
+			["weights"] = 1.0; -- Chance of changing the weights
+			["connection"] = 1.0; -- Chance of changing the connections (add a gene)
+			["bias"] = 1.0; -- Chance of connecting to the bias
+			["split"] = 1.0; -- Chance of splitting a gene and adding a neuron
+			["enable"] = 1.0; -- Chance of enabling a gene
+			["disable"] = 1.0; -- Chance of disabling a gene
+		}
+	}
+
+	setmetatable(o, Genome_mt)
+	return o
+end
+
+function Genome:add_gene(from, to, weight)
+       local gene = Gene.new()
+       gene.weight = weight
+       gene.from = from
+       gene.to = to
+       gene.innovation = Get_Next_Innovation()
+
+       table.insert(self.genes, gene)
+end
+
+function Genome:copy()
+	-- Deep-copies this genome: genes, I/O counts, highest neuron id and
+	-- mutation rates. Fitness and the cached network are intentionally
+	-- NOT copied (a fresh copy has to be re-evaluated anyway).
+	-- BUG FIX: this called Genome.new() with no arguments, which crashed on
+	-- `inputs + 1` (arithmetic on nil). Genome.new adds 1 for the bias
+	-- neuron, so pass the raw input count back in.
+	local newG = Genome.new(self.num_inputs - 1, self.num_outputs)
+	for g = 1, #self.genes do
+		table.insert(newG.genes, self.genes[g]:copy())
+	end
+
+	newG.num_inputs = self.num_inputs
+	newG.num_outputs = self.num_outputs -- BUG FIX: was misspelled `num_ouputs`, leaving the real field nil
+
+	newG.high_neuron = self.high_neuron
+
+	for mut_name, val in pairs(self.mutations) do
+		newG.mutations[mut_name] = val
+	end
+
+	return newG
+end
+
+function Genome:create_network()
+	-- Builds the phenotype (a NeuralNetwork) from this genome's enabled
+	-- genes and caches it on self.network.
+	-- Neurons are created lazily: any neuron id an enabled gene references
+	-- that the network does not yet have gets a node on demand.
+	local net = NeuralNetwork.new(self.num_inputs, self.num_outputs)
+
+	for i = 1, #self.genes do
+		local gene = self.genes[i]
+
+		if gene.enabled then
+			if not net:has_neuron(gene.to) then
+				net:create_neuron(gene.to)
+			end
+
+			-- Passes the gene table itself (one-argument form of
+			-- add_connection); the network stores a reference to the
+			-- gene, not a copy -- later weight mutations are shared
+			net:add_connection(gene)
+
+			if not net:has_neuron(gene.from) then
+				net:create_neuron(gene.from)
+			end
+		end
+	end
+
+	self.network = net
+end
+
+function Genome:has_gene(from, to)
+       for i = 1, #self.genes do
+               local gene = self.genes[i]
+
+               if gene.to == to and gene.from == from then
+                       return true
+               end
+       end
+
+       return false
+end
+
+-- Randomly changes the genes (weights)
+function Genome:mutate_weights()
+       local change = 0.2
+
+       for i = 1, #self.genes do
+               local gene = self.genes[i]
+
+               -- Just some constant, probably put that somewhere... eventually
+               if math.random() < 0.8 then
+                       gene.weight = gene.weight + math.random() * change * 2 - change -- (-change, change)
+               else
+                       gene.weight = math.random() * 4 - 2 -- Randomly change it to be in (-2, 2)
+               end
+       end
+end
+
+-- Randomly adds a new gene (connection) between two neurons already known
+-- to the genome, with a random weight in (-2, 2).
+-- connect_to_bias: when true, force the connection to originate from the
+-- bias neuron (id == num_inputs, the neuron that is always fed 1).
+function Genome:mutate_connections(connect_to_bias)
+	local neuron1 = self:get_random_neuron(true) -- Could be an input
+	local neuron2 = self:get_random_neuron(false) -- NOT an input
+
+	if connect_to_bias then
+		-- BUG FIX: the bias neuron must be the SOURCE of the new
+		-- connection; the old code assigned it to neuron2, wiring the
+		-- connection INTO the bias instead of out of it
+		neuron1 = self.num_inputs -- This is the id of the bias neuron
+	end
+
+	-- Don't add a duplicate connection
+	if self:has_gene(neuron1, neuron2) then
+		return
+	end
+
+	local weight = math.random() * 4 - 2
+	self:add_gene(neuron1, neuron2, weight)
+end
+
+-- Randomly splits a gene into 2 (adding a neuron in the process).
+-- The original connection A->B is disabled and replaced by A->C and C->B,
+-- where C is a newly allocated neuron id.
+function Genome:mutate_neuron()
+	if #self.genes == 0 then
+		return
+	end
+
+	-- Get a random gene
+	local gene = self.genes[math.random(1, #self.genes)]
+
+	-- Skip the gene if it is not enabled
+	if not gene.enabled then
+		return
+	end
+
+	-- Only claim a new neuron id once we know the split will happen.
+	-- BUG FIX: this used to be incremented before the enabled check, so
+	-- every early return leaked a phantom neuron id into high_neuron.
+	self.high_neuron = self.high_neuron + 1
+
+	-- Disable the gene because we are about to add two others to replace it
+	gene.enabled = false
+
+	-- Per the NEAT paper, the connection INTO the new node gets weight 1
+	-- and the connection OUT of it keeps the old weight, so the split
+	-- initially perturbs the network's behavior as little as possible.
+	-- (The original code had the two weights swapped.)
+	local gene1 = gene:copy()
+	gene1.to = self.high_neuron
+	gene1.weight = 1.0
+	gene1.innovation = Get_Next_Innovation()
+	gene1.enabled = true
+
+	table.insert(self.genes, gene1)
+
+	local gene2 = gene:copy()
+	gene2.from = self.high_neuron
+	gene2.innovation = Get_Next_Innovation()
+	gene2.enabled = true
+
+	table.insert(self.genes, gene2)
+end
+
+-- Randomly turns on or off a gene, depending on the parameter
+function Genome:mutate_enabled(enabled)
+       local possible = {}
+       for _, gene in ipairs(self.genes) do
+               if gene.enabled == enabled then
+                       table.insert(possible, gene)
+               end
+       end
+
+       if #possible == 0 then
+               return
+       end
+
+       local gene = possible[math.random(1, #possible)]
+       gene.enabled = not gene.enabled
+end
+
+function Genome:mutate()
+	-- Applies every kind of mutation to this genome, each with the
+	-- probability stored in self.mutations. Rates above 1.0 can trigger
+	-- the same mutation several times in one call.
+
+	-- Randomize the rate that mutations can happen
+	for mut_name, rate in pairs(self.mutations) do
+		if math.random() < 0.5 then
+			self.mutations[mut_name] = 0.96 * rate -- Slightly decrease rate
+		else
+			self.mutations[mut_name] = 1.04 * rate -- Slightly increase rate
+		end
+	end
+
+	if math.random() < self.mutations["weights"] then
+		self:mutate_weights()
+	end
+
+	-- Helper: roll against `rate` once per whole unit of probability
+	-- (e.g. rate 2.5 -> two guaranteed-odds rolls plus one at 50%)
+	local function roll(rate, fn)
+		while rate > 0 do
+			if math.random() < rate then
+				fn()
+			end
+			rate = rate - 1
+		end
+	end
+
+	-- BUG FIX: all five lookups below read `self.mutation` (singular),
+	-- which is nil, so every call crashed with "attempt to index a nil
+	-- value". Also the key was "connections" but the table declares
+	-- "connection".
+	roll(self.mutations["connection"], function() self:mutate_connections(false) end)
+	roll(self.mutations["bias"], function() self:mutate_connections(true) end)
+	roll(self.mutations["split"], function() self:mutate_neuron() end)
+	roll(self.mutations["enable"], function() self:mutate_enabled(true) end)
+	roll(self.mutations["disable"], function() self:mutate_enabled(false) end)
+end
+
+function Genome:get_random_neuron(can_be_input)
+	-- Picks a uniformly random neuron id known to this genome.
+	-- can_be_input: when false, input neurons (ids 1..num_inputs, which
+	-- includes the bias neuron) are excluded from the pool.
+	local genes = self.genes
+
+	local neurons = {}
+
+	if can_be_input then
+		for i = 1, self.num_inputs do
+			neurons[i] = true
+		end
+	end
+
+	-- Output neurons sit right after the inputs and are always eligible
+	for o = 1, self.num_outputs do
+		neurons[o + self.num_inputs] = true
+	end
+
+	-- BUG FIX: the old guard `can_be_input or genes[i].to` was always true
+	-- (every number is truthy in Lua), so input neurons leaked into the
+	-- pool even when can_be_input was false. Compare against num_inputs
+	-- instead: only ids above it are hidden/output neurons.
+	for i = 1, #genes do
+		if can_be_input or genes[i].to > self.num_inputs then
+			neurons[genes[i].to] = true
+		end
+		if can_be_input or genes[i].from > self.num_inputs then
+			neurons[genes[i].from] = true
+		end
+	end
+
+	-- The key set is not contiguous, so count entries with pairs
+	local cnt = 0
+	for _ in pairs(neurons) do
+		cnt = cnt + 1
+	end
+
+	local choice = math.random(1, cnt)
+
+	-- ...and walk the same (unordered but stable within one call) key set
+	-- to find the chosen entry
+	for k, _ in pairs(neurons) do
+		choice = choice - 1
+
+		if choice == 0 then
+			return k
+		end
+	end
+
+	return 0
+end
+
+function Genome:crossover(other)
+	-- NEAT crossover: produces a child genome from self and other.
+	-- Matching genes (same innovation number) are inherited from either
+	-- parent at random; all other genes come from the fitter parent.
+
+	-- Need to make sure that genome1 is the one with the better fitness
+	local genome1 = self
+	local genome2 = other
+
+	if genome1.fitness < genome2.fitness then
+		local tmp = genome1
+		genome1 = genome2
+		genome2 = tmp
+	end
+
+	-- BUG FIX: Genome.new(inputs, ...) adds 1 to `inputs` for the bias
+	-- neuron; num_inputs already includes it, so pass the raw count or
+	-- every generation of children gains an extra input neuron
+	local child = Genome.new(genome1.num_inputs - 1, genome1.num_outputs)
+
+	-- Index the worse genome's genes by innovation number
+	local innov2 = {}
+	for i = 1, #genome2.genes do
+		local gene = genome2.genes[i]
+		innov2[gene.innovation] = gene
+	end
+
+	-- Walk the better genome's genes, with a chance of keeping the
+	-- matching (enabled) gene from the worse genome instead
+	for i = 1, #genome1.genes do
+		local gene1 = genome1.genes[i]
+		local gene2 = innov2[gene1.innovation]
+
+		if gene2 ~= nil and math.random() > 0.5 and gene2.enabled then
+			table.insert(child.genes, gene2:copy())
+		else
+			table.insert(child.genes, gene1:copy())
+		end
+	end
+
+	child.high_neuron = math.max(genome1.high_neuron, genome2.high_neuron)
+
+	return child
+end
+
+-- Population class --
+
+local Population = {}
+local Population_mt = { __index = Population }
+
+function Population.new()
+       local o = {
+               genomes = {};
+               generation = 0;
+               high_fitness = 0;
+               avg_fitness = 0;
+       }
+
+       setmetatable(o, Population_mt)
+       return o
+end
+
+function Population:create_genomes(num, inputs, outputs)
+       local genomes = self.genomes
+
+       for i = 1, num do
+               genomes[i] = Genome.new(inputs, outputs)
+       end
+end
+
+function Population:breed_genome()
+       local genomes = self.genomes
+       local child
+
+       -- Another random constant that should be in a global variable
+       if math.random() < 0.4 then
+               local g1 = genomes[math.random(1, #genomes)]
+               local g2 = genomes[math.random(1, #genomes)]
+               child = g1:crossover(g2)
+       else
+               local g = genomes[math.random(1, #genomes)]
+               child = g:copy()
+       end
+
+       child:mutate()
+
+       return child
+end
+
+function Population:kill_worst()
+	-- Culls the bottom half of the population by fitness.
+	-- Sort best-first (descending fitness) so the worst genomes end up at
+	-- the tail of the array -- this is the correct direction for the
+	-- tail-removal below
+	table.sort(self.genomes, function(a, b)
+		return a.fitness > b.fitness
+	end)
+
+	-- Drop the worst half (floor keeps the better half on odd counts)
+	local count = math.floor(#self.genomes / 2)
+	for _ = 1, count do
+		table.remove(self.genomes) -- This removes the last (worst) genome
+	end
+
+	collectgarbage() -- Since we just freed a bunch of memory, best to do this now instead of letting it pile up
+end
+
+function Population:mate()
+       local count = #self.genomes
+
+       -- Double the population size
+       for _ = 1, count do
+               table.insert(self.genomes, self:breed_genome())
+       end
+
+       self.generation = self.generation + 1
+end
index 9002297e1412afcc1953b953ad9a73b76abaec9a..baafe4a1a59205890a3fd80b021fee9d4fce2184 100644 (file)
@@ -4,7 +4,7 @@ local KEYS = CONF.KEYS
 -- INPUT --
 
 local Input = {}
-function Input:new()
+function Input.new()
        local o = {
                move_up    = false;
                move_down  = false;
@@ -24,7 +24,7 @@ end
 
 -- Ugly way of righting it but I don't care (right now at least... :P)
 function Input:keydown(key)
-       if     key == KEYS.MOVE_UP    then self.move_up    = true 
+       if     key == KEYS.MOVE_UP    then self.move_up    = true
        elseif key == KEYS.MOVE_DOWN  then self.move_down  = true
        elseif key == KEYS.MOVE_LEFT  then self.move_left  = true
        elseif key == KEYS.MOVE_RIGHT then self.move_right = true
@@ -36,7 +36,7 @@ function Input:keydown(key)
 end
 
 function Input:keyup(key)
-       if     key == KEYS.MOVE_UP    then self.move_up    = false 
+       if     key == KEYS.MOVE_UP    then self.move_up    = false
        elseif key == KEYS.MOVE_DOWN  then self.move_down  = false
        elseif key == KEYS.MOVE_LEFT  then self.move_left  = false
        elseif key == KEYS.MOVE_RIGHT then self.move_right = false
diff --git a/src/neuralnet.lua b/src/neuralnet.lua
new file mode 100644 (file)
index 0000000..720157a
--- /dev/null
@@ -0,0 +1,130 @@
+-- Simple neural network implementation (perceptron)
+
+local Neuron = {}
+function Neuron.new()
+       local o = {
+               value = 0;
+               inputs = {};
+               dirty = false; -- Means that the value of the neuron has to be recalculated
+       }
+       return o
+end
+
+
+-- Every node has a ID which is used as the key to the neurons array
+
+local NeuralNetwork = {}
+local NeuralNetwork_mt = { __index = NeuralNetwork }
+
+function NeuralNetwork.new(num_inputs, num_outputs)
+	-- Builds a network containing only input and output neurons and no
+	-- connections. Hidden neurons added later get ids starting at
+	-- num_inputs + num_outputs + 1.
+	local o = {
+		neurons = {};
+		num_inputs = num_inputs;
+		num_outputs = num_outputs;
+		next_neuron = num_inputs + num_outputs + 1; -- first free hidden-neuron id
+	}
+
+	-- 1 to num_inputs are input nodes
+	for i = 1, num_inputs do
+		o.neurons[i] = Neuron.new()
+	end
+
+	-- num_inputs + 1 to num_inputs + num_outputs are output nodes
+	for i = num_inputs + 1, num_inputs + num_outputs do
+		o.neurons[i] = Neuron.new()
+	end
+
+	setmetatable(o, NeuralNetwork_mt)
+	return o
+end
+
+function NeuralNetwork:add_connection(from, to, weight, id)
+	-- Registers a connection as an input of its target neuron.
+	-- Two calling conventions:
+	--   net:add_connection(gene)                 -- table with .from/.to/.weight
+	--   net:add_connection(from, to, weight, id) -- four scalars
+	local neurons = self.neurons
+
+	if type(from) == "table" then
+		-- BUG FIX: this branch indexed neurons[to], but `to` is nil in
+		-- the one-argument form (used by Genome:create_network), which
+		-- crashed; the target id lives inside the gene table
+		table.insert(neurons[from.to].inputs, from)
+	else
+		table.insert(neurons[to].inputs, {
+			to = to;
+			from = from;
+			weight = weight;
+			id = id;
+		})
+	end
+end
+
+function NeuralNetwork:add_neuron()
+       self.neurons[self.next_neuron] = Neuron.new()
+       self.next_neuron = self.next_neuron + 1
+       return self.next_neuron - 1
+end
+
+function NeuralNetwork:create_neuron(num)
+       if self.next_neuron < num then
+               self.next_neuron = num + 1 -- Makes sure the next neuron won't override previous neurons
+       end
+
+       self.neurons[num] = Neuron.new()
+end
+
+function NeuralNetwork:has_neuron(num)
+       return self.neurons[num] ~= nil
+end
+
+function NeuralNetwork:activate(inputs)
+	-- Feeds `inputs` (array of num_inputs values) through the network.
+	-- All non-input neurons are marked dirty, then recomputed; fetch the
+	-- results afterwards with get_outputs().
+	local ns = self.neurons
+
+	for i = 1, self.num_inputs do
+		self.neurons[i].value = inputs[i]
+	end
+
+	for i = self.num_inputs + 1, #ns do
+		ns[i].dirty = true
+	end
+
+	-- Iterate backwards since the hidden nodes are going to be at the end of the array
+	-- NOTE(review): #ns assumes neuron ids are contiguous 1..n; ids
+	-- created sparsely via create_neuron would make these loops skip
+	-- nodes -- confirm against how Genome:create_network allocates ids
+	for i = #ns, self.num_inputs + 1, -1 do
+		if ns[i].dirty then
+			self:activate_neuron(i)
+		end
+	end
+end
+
+function NeuralNetwork:activate_neuron(neuron)
+	-- Recomputes the value of one neuron (by id): recursively activates
+	-- any dirty neurons feeding into it, then applies the sigmoid to the
+	-- weighted sum of its inputs.
+	-- NOTE(review): a cycle of connections would recurse forever, since a
+	-- neuron is only marked clean after its inputs finish -- confirm the
+	-- genome never produces cyclic connections
+	local n = self.neurons[neuron]
+
+	if not n.dirty then return end
+
+	if #n.inputs > 0 then
+		local sum = 0
+		for i = 1, #n.inputs do
+			local e = n.inputs[i]
+			if self.neurons[e.from].dirty then
+				self:activate_neuron(e.from)
+			end
+
+			sum = sum + self.neurons[e.from].value * e.weight
+		end
+
+		n.value = math.sigmoid(sum) -- fast sigmoid from src/utils.lua
+	else
+		n.value = 0 -- no incoming connections: value pinned to 0
+	end
+
+	n.dirty = false
+end
+
+function NeuralNetwork:get_outputs()
+       local ret = {}
+
+       for i = 1, self.num_outputs do
+               ret[i] = self.neurons[i + self.num_inputs].value
+       end
+
+       return ret
+end
+
+return {
+       NeuralNetwork = NeuralNetwork;
+       Neuron = Neuron;
+}
index 2da31fa9e0f4190755421055730e6c9186e9a66e..f3ca6076dc780a5a7fc4e9f216c3f57cd661d615 100644 (file)
@@ -6,6 +6,11 @@ function math.sqrDist(x1, y1, x2, y2)
        return (x1 - x2) * (x1 - x2) + (y1 - y2) * (y1 - y2)
 end
 
+function math.sigmoid(x)
+       -- Fast sigmoid
+       return x / (1 + math.abs(x))
+end
+
 function math.genuuid()
        return ("xxxxxxxx-xxxx-4yxx-xxxxxxxx"):gsub('[xy]', function (c)
         local v = (c == 'x') and math.random(0, 0xf) or math.random(8, 0xb)
@@ -31,3 +36,4 @@ end
 function reversedipairs(t)
        return ripairsiter, t, #t + 1
 end
+
index 457683d8ad0072d62ace54bc0b27e21b87be630d..0cd6baeb069b5acbf7afd20c7c64e3d89de1b725 100644 (file)
@@ -8,7 +8,7 @@ local Bullet = {}
 local Bullet_mt = { __index = Bullet }
 Bullet.ENTITY_TYPE = "Bullet"
 
-function Bullet:new(x, y, vx, vy)
+function Bullet.new(x, y, vx, vy)
        local o = {
                x = x;
                y = y;
@@ -17,7 +17,7 @@ function Bullet:new(x, y, vx, vy)
                life = 80;
                alive = true;
        }
-       
+
        setmetatable(o, Bullet_mt)
        return o
 end
@@ -77,7 +77,7 @@ local Player = {}
 local Player_mt = { __index = Player }
 Player.ENTITY_TYPE = "Player"
 
-function Player:new()
+function Player.new()
        local o = {
                x = CONF.WINDOW_WIDTH / 2;
                y = CONF.WINDOW_HEIGHT / 2;
@@ -86,7 +86,7 @@ function Player:new()
 
                distances = {};
        }
-       
+
        setmetatable(o, Player_mt)
        return o
 end
@@ -109,7 +109,7 @@ function Player:update(dt, world, input)
                local firey = 0
 
                local FIRE_SPEED = 300
-               
+
                if input.fire_up    then firey = firey - 1 end
                if input.fire_down  then firey = firey + 1 end
                if input.fire_left  then firex = firex - 1 end
@@ -132,8 +132,8 @@ function Player:update(dt, world, input)
 end
 
 function Player:fire(vx, vy, world)
-       local bullet = Bullet:new(self.x, self.y, vx, vy)
-       world:add_entity(bullet)        
+       local bullet = Bullet.new(self.x, self.y, vx, vy)
+       world:add_entity(bullet)
 end
 
 function Player:get_rect()
@@ -153,6 +153,8 @@ function Player:get_distances(world)
 
                local hit_entity = false
                for j = 1, CONF.PLAYER_VISION_DISTANCE do
+                       if hit_entity then break end
+
                        local tx = self.x + dx * j
                        local ty = self.y + dy * j
 
@@ -161,7 +163,7 @@ function Player:get_distances(world)
                                        local ent_rect = e:get_rect()
 
                                        local toggle = false
-                                       for k = 0, 20 do
+                                       for _ = 0, 20 do
                                                dx = dx / 2
                                                dy = dy / 2
                                                tx = tx - dx
@@ -218,7 +220,7 @@ local Enemy = {}
 local Enemy_mt = { __index = Enemy }
 Enemy.ENTITY_TYPE = "Enemy"
 
-function Enemy:new(x, y)
+function Enemy.new(x, y)
        local o = {
                x = x;
                y = y;
@@ -231,7 +233,7 @@ end
 
 function Enemy:update(dt, world)
        local player = world.player
-       
+
        local a = math.atan2(player.y - self.y, player.x - self.x)
        local dx = math.cos(a)
        local dy = math.sin(a)
@@ -260,9 +262,9 @@ end
 
 local World = {}
 local World_mt = { __index = World }
-function World:new(player)
+function World.new(player)
        if player == nil then
-               player = Player:new()
+               player = Player.new()
        end
 
        local o = {
@@ -312,11 +314,11 @@ function World:remove_entity(ent_or_id)
                        break
                end
        end
-       
+
        table.remove(self.entities, pos)
 end
 
--- Assumes ent has x and y
+-- Assumes ent has x, y and get_rect
 function World:move_entity(ent, dx, dy)
        ent.x = ent.x + dx
        for _, e in ipairs(self.entities) do
diff --git a/test_nn.lua b/test_nn.lua
new file mode 100644 (file)
index 0000000..a5e2edc
--- /dev/null
@@ -0,0 +1,23 @@
+require "src.utils"
+local NN = require "src.neuralnet"
+local NeuralNetwork = NN.NeuralNetwork
+
+local net = NeuralNetwork.new(4, 4)
+
+net:activate({ 0, 1, 0, 1 })
+local tmp = net:get_outputs()
+for k, v in ipairs(tmp) do
+       print(k, v)
+end
+
+net:add_connection(1, 5, 1, 0)
+net:add_connection(2, 5, 1, 0)
+net:add_connection(3, 5, 1, 0)
+net:add_connection(4, 5, 1, 1)
+
+net:activate({ 1, 1, 1, 1 })
+tmp = net:get_outputs()
+for k, v in ipairs(tmp) do
+       print(k, v)
+end
+