Polishing the product and improving the network training
authorBrendan Hansen <brendan.f.hansen@gmail.com>
Thu, 14 Mar 2019 05:51:36 +0000 (00:51 -0500)
committerBrendan Hansen <brendan.f.hansen@gmail.com>
Thu, 14 Mar 2019 05:51:36 +0000 (00:51 -0500)
conf.lua
docs/TODO [new file with mode: 0644]
main.lua
src/genetics.lua
src/input.lua
src/neuralnet.lua
src/trainer.lua [new file with mode: 0644]
src/utils.lua
src/world.lua

index c66ef232b619432223307d30665824e5921e20a4..6918574ad89a6c56e5520cc1528996d714ee6d3d 100644 (file)
--- a/conf.lua
+++ b/conf.lua
@@ -1,6 +1,6 @@
 -- Window will be small since the graphics are not very important
-local WINDOW_WIDTH = 800
-local WINDOW_HEIGHT = 600
+local WINDOW_WIDTH = 1200
+local WINDOW_HEIGHT = 800
 
 function love.conf(t)
        t.window.title = "Maching Learning Game"
@@ -46,8 +46,11 @@ return {
        WINDOW_HEIGHT = WINDOW_HEIGHT;
        KEYS = KEYMAP;
 
-       BACK_COLOR = { 0.8, 0.8, 0.8 };
-       PLAYER_COLOR = { 0.3, 0.3, 0.7 };
+       BACK_COLOR = { 0.1, 0.1, 0.15 };
+       FONT_COLOR = { 1.0, 1.0, 1.0 };
+
+       PLAYER_COLOR = { 0.7, 0.7, 0.96 };
+       PLAYER_VISION_COLOR = { 0.7, 0.7, 0.7 };
        ENEMY_COLOR = { 1.0, 0.0, 0.0 };
        BULLET_COLOR = { 0.6, 0.6, 1.0 };
 
@@ -57,4 +60,17 @@ return {
        ENEMY_SIZE = 14;
 
        MAX_NEURONS = 1024;
+
+       -- How many of the genomes tested survive
+       GENOME_THRESHOLD = 1 / 5;
+
+       Starting_Weights_Chance = 0.25;
+       Starting_Connection_Chance = 2.0;
+       Starting_Bias_Chance = 0.2;
+       Starting_Split_Chance = 0.5;
+       Starting_Enable_Chance = 0.2;
+       Starting_Disable_Chance = 0.4;
+
+       Reset_Weight_Chance = 0.9;
+       Crossover_Chance = 0.75;
 }
diff --git a/docs/TODO b/docs/TODO
new file mode 100644 (file)
index 0000000..3f64f0b
--- /dev/null
+++ b/docs/TODO
@@ -0,0 +1,12 @@
+Things to fix tonight...
+
+* Need to be able to save and load the generations
+
+\, main.lua has a lot of logic that can be split up
+       \, Should have a "Tester" class that encapsulates the updating and handling population growth
+       - Should have a Statistics class that calculates basic stats on a list of numbers
+
+- Way to "manually" play the game
+
+- Need to be able to train the AI without running the visuals
+       - Separate logic
index ce94f6d25b2c97a4f6513878b11b9f26c2f52381..3936dad2f9a291b85dd448029d363388fc4e9c59 100644 (file)
--- a/main.lua
+++ b/main.lua
@@ -3,6 +3,7 @@ local CONF = require "conf"
 local world_mod = require "src.world"
 local Input = require "src.input"
 local Gen = require "src.genetics"
+local Trainer = (require "src.trainer").Trainer
 
 local World = world_mod.World
 local Enemy = world_mod.Enemy
@@ -12,16 +13,19 @@ local Population = Gen.Population
 local world, player
 local input
 local pop
-local pop_update
+local trainer
 
 local update_speed = 30
 
+local ui_font
 local fitness_font
 
 local stored_fitnesses = {}
 
 local enemies = {}
 function love.load()
+       math.randomseed(os.time())
+
        world, player = World.new()
        local enemy = Enemy.new(0, 0)
        table.insert(enemies, enemy)
@@ -40,10 +44,13 @@ function love.load()
 
        pop = Population.new()
        pop:create_genomes(96, 16, 8)
-       pop_update = pop:evolve()
+
+       trainer = Trainer.new(pop, world, input)
+       trainer:initialize_training()
 
        love.graphics.setBackgroundColor(CONF.BACK_COLOR)
-       fitness_font = love.graphics.newFont(24)
+       ui_font = love.graphics.newFont(24)
+       fitness_font = love.graphics.newFont(32)
 end
 
 function love.keypressed(key)
@@ -54,70 +61,6 @@ function love.keyreleased(key)
        input:keyup(key)
 end
 
-local function get_random_pos()
-       local x = math.random(100) + math.random(100) + 600 * (math.random(2) - 1)
-       local y = math.random(100) + math.random(100) + 500 * (math.random(2) - 1)
-       return x, y
-end
-
-local function network_input(ins, dt)
-       player.alive = true
-       if ins[1] > 0.35 then input:keydown("w") else input:keyup("w") end
-       if ins[2] > 0.35 then input:keydown("s") else input:keyup("s") end
-       if ins[3] > 0.35 then input:keydown("a") else input:keyup("a") end
-       if ins[4] > 0.35 then input:keydown("d") else input:keyup("d") end
-       if ins[5] > 0.35 then input:keydown("left") else input:keyup("left") end
-       if ins[6] > 0.35 then input:keydown("right") else input:keyup("right") end
-       if ins[7] > 0.35 then input:keydown("up") else input:keyup("up") end
-       if ins[8] > 0.35 then input:keydown("down") else input:keyup("down") end
-
-       local last_x = player.x
-       local last_y = player.y
-
-       world:update(dt, input)
-
-       local fitness = math.sqrt(math.sqrDist(last_x, last_y, player.x, player.y))
-       fitness = fitness - (player.shot and 1 or 0)
-
-       local enemies_alive = 0
-       for _, v in ipairs(enemies) do
-               if v.alive then
-                       enemies_alive = enemies_alive + 1
-               else
-                       if not v.__tagged then
-                               v.__tagged = true
-                               fitness = fitness + 400
-                       end
-               end
-       end
-
-       if not player.alive or enemies_alive == 0 then
-               for _, v in ipairs(enemies) do
-                       world:remove_entity(v)
-               end
-
-               enemies = {}
-
-               for _ = 1, math.ceil((pop.generation + 1) / 10) do
-                       local enemy = Enemy.new(get_random_pos())
-                       world:add_entity(enemy)
-                       table.insert(enemies, enemy)
-               end
-
-               if player.alive then
-                       fitness = fitness + 2000
-               else
-                       player.x = 400
-                       player.y = 300
-               end
-       end
-
-       return fitness, player.alive
-end
-
-local function generation_step(avg_fitness, _, _)
-       table.insert(stored_fitnesses, avg_fitness)
-end
 
 function love.update(dt)
        if love.keyboard.isDown "escape" then
@@ -125,33 +68,15 @@ function love.update(dt)
        end
 
        if love.keyboard.isDown "z" then
-               update_speed = update_speed - 1
-               if update_speed < 1 then
-                       update_speed = 1
-               end
+               trainer:change_speed(-1)
        end
 
        if love.keyboard.isDown "x" then
-               update_speed = update_speed + 1
-               if update_speed > 60 then
-                       update_speed = 60
-               end
+               trainer:change_speed(1)
        end
 
-       for _ = 1, update_speed do
-               local dists = player:get_distances(world)
-
-               local inputs = {}
-               for i = 1, 16 do
-                       local v1 = dists[i * 2]
-                       local v2 = dists[(i * 2 + 1) % 32]
-                       local v3 = dists[(i * 2 - 1) % 32]
-
-                       inputs[i] = 1 - ((0.5 * v1 + 0.25 * v2 + 0.25 * v3) / (CONF.ENEMY_SIZE * CONF.PLAYER_VISION_DISTANCE))
-               end
-
-               pop_update = pop_update(inputs, network_input, generation_step, dt)
-       end
+       trainer:update(dt)
+       --world:update(dt, input)
 end
 
 local function plot_fitness(x, y, scale)
@@ -160,13 +85,13 @@ local function plot_fitness(x, y, scale)
        love.graphics.scale(scale, scale)
 
        love.graphics.setColor(0, 0, 0, 0.4)
-       love.graphics.rectangle("fill", -20, -20, 440, 240)
+       love.graphics.rectangle("fill", -20, -20, 680, 340)
 
        love.graphics.setFont(fitness_font)
-       love.graphics.setColor(1, 1, 1)
+       love.graphics.setColor(CONF.FONT_COLOR)
 
-       love.graphics.printf("Average fitness: " .. math.floor(pop.avg_fitness), 0, 0, 400, "left")
-       love.graphics.printf("Highest fitness: " .. math.floor(pop.high_fitness), 0, 20, 400, "left")
+       love.graphics.printf("Average fitness: " .. math.floor(pop.avg_fitness), 0, 0, 640, "left")
+       love.graphics.printf("Highest fitness: " .. math.floor(pop.high_fitness), 0, 32, 640, "left")
 
        local highest = 0
        for _, v in ipairs(stored_fitnesses) do
@@ -175,14 +100,14 @@ local function plot_fitness(x, y, scale)
                end
        end
 
-       local width = 400 / (#stored_fitnesses)
+       local width = 640 / (#stored_fitnesses)
 
        love.graphics.setColor(0, 0, 1)
        for i, v in ipairs(stored_fitnesses) do
                if v < 0 then
                        v = 0
                end
-               love.graphics.circle("fill", (i - 1) * width, 200 - v * 100 / highest, 8)
+               love.graphics.circle("fill", (i - 1) * width, 300 - v * 200 / highest, 8)
        end
 
        love.graphics.pop()
@@ -198,9 +123,19 @@ local function draw_network(net, x, y, scale)
        love.graphics.setColor(0, 0, 0, 0.4)
        love.graphics.rectangle("fill", -20, -20, 680, 600)
 
-       love.graphics.setColor(1, 1, 1)
-
        for _, v in pairs(net.neurons) do
+               local c = v.value
+               local r = c < 0 and c or 0
+               local b = c > 0 and c or 0
+               local g = 0
+
+               if v.value == 0 then
+                       r = 0.3
+                       g = 0.3
+                       b = 0.3
+               end
+
+               love.graphics.setColor(r, g, b)
                love.graphics.rectangle("fill", v.x, v.y, 24, 24)
        end
 
@@ -219,6 +154,11 @@ local function draw_network(net, x, y, scale)
                                col = { 0, 0, 1 }
                        end
 
+                       local mag = math.abs(other.value)
+                       col[1] = col[1] * mag
+                       col[2] = col[2] * mag
+                       col[3] = col[3] * mag
+
                        love.graphics.setColor(col)
                        love.graphics.setLineWidth(math.sigmoid(conn.weight) * 2)
                        love.graphics.line(x1, y1, x2, y2)
@@ -230,17 +170,20 @@ local function draw_network(net, x, y, scale)
 end
 
 function love.draw()
+       love.graphics.setScissor(0, 0, 820, 620)
        world:draw()
+       love.graphics.setScissor()
 
-       love.graphics.setColor(0, 0, 0)
-       love.graphics.printf(tostring(love.timer.getFPS()) .. " FPS", 0, 0, 800, "left")
-       love.graphics.printf("Generation: " .. pop.generation, 0, 32, 800, "left")
-       love.graphics.printf("Genome: " .. pop.current_genome, 0, 64, 800, "left")
+       love.graphics.setColor(CONF.FONT_COLOR)
+       love.graphics.setFont(ui_font)
+       love.graphics.printf(tostring(love.timer.getFPS()) .. " FPS", 16, 640, 800, "left")
+       love.graphics.printf("Generation: " .. pop.generation, 16, 640 + 32, 800, "left")
+       love.graphics.printf("Genome: " .. pop.current_genome, 16, 640 + 64, 800, "left")
        if pop.genomes[pop.current_genome] ~= nil then
-               love.graphics.printf("Fitness: " .. math.floor(pop.genomes[pop.current_genome].fitness), 0, 96, 800, "left")
+               love.graphics.printf("Fitness: " .. math.floor(pop.genomes[pop.current_genome].fitness), 16, 640 + 96, 800, "left")
 
-               draw_network(pop.genomes[pop.current_genome].network, 580, 0, 1 / 3)
+               draw_network(pop.genomes[pop.current_genome].network, 1200 - 350, 32, 1 / 2)
        end
 
-       plot_fitness(250, 0, 3 / 4)
+       --plot_fitness(1200 - 350, 352, 1 / 2)
 end
index 68bb5bba929f484565f7699f938d946e425cb008..1516ae18756c4ba8441be51e58ef525036cc1f55 100644 (file)
@@ -1,20 +1,7 @@
 local NN = require "src.neuralnet"
-local conf = require "conf"
+local CONF = require "conf"
 local NeuralNetwork = NN.NeuralNetwork
 
--- Globals
-local Starting_Weights_Chance = 0.25
-local Starting_Connection_Chance = 2.0
-local Starting_Bias_Chance = 0.2
-local Starting_Split_Chance = 0.5
-local Starting_Enable_Chance = 0.2
-local Starting_Disable_Chance = 0.4
-
-local Reset_Weight_Chance = 0.9
-local Crossover_Chance = 0.75
-
-local MAX_NEURONS = conf.MAX_NEURONS
-
 -- Need a global-ish innovation number, since that depends on the whole training, not just a single genome
 local Current_Innovation = 1
 
@@ -57,7 +44,6 @@ end
 
 -- Genome class --
 
-
 local Genome = {}
 local Genome_mt = { __index = Genome }
 
@@ -71,12 +57,12 @@ function Genome.new(inputs, outputs)
                high_neuron = inputs + 1; -- Highest numbered neuron in the genome
 
                mutations = { -- The different chances of mutating a particular part of the genome
-                       ["weights"] = Starting_Weights_Chance; -- Chance of changing the weights
-                       ["connection"] = Starting_Connection_Chance; -- Chance of changing the connections (add a gene)
-                       ["bias"] = Starting_Bias_Chance; -- Chance of connecting to the bias
-                       ["split"] = Starting_Split_Chance; -- Chance of splitting a gene and adding a neuron
-                       ["enable"] = Starting_Enable_Chance; -- Chance of enabling a gene
-                       ["disable"] = Starting_Disable_Chance; -- Chance of disablign a gene
+                       ["weights"] = CONF.Starting_Weights_Chance; -- Chance of changing the weights
+                       ["connection"] = CONF.Starting_Connection_Chance; -- Chance of changing the connections (add a gene)
+                       ["bias"] = CONF.Starting_Bias_Chance; -- Chance of connecting to the bias
+                       ["split"] = CONF.Starting_Split_Chance; -- Chance of splitting a gene and adding a neuron
+                       ["enable"] = CONF.Starting_Enable_Chance; -- Chance of enabling a gene
+                       ["disable"] = CONF.Starting_Disable_Chance; -- Chance of disablign a gene
                }
        }
 
@@ -152,7 +138,7 @@ function Genome:mutate_weights()
        for i = 1, #self.genes do
                local gene = self.genes[i]
 
-               if math.random() < Reset_Weight_Chance then
+               if math.random() < CONF.Reset_Weight_Chance then
                        gene.weight = gene.weight + math.random() * change * 2 - change -- (-change, change)
                else
                        gene.weight = math.random() * 4 - 2 -- Randomly change it to be in (-2, 2)
@@ -182,7 +168,7 @@ function Genome:mutate_connections(connect_to_bias)
        end
 
        -- Output cant be input
-       if neuron1 >= MAX_NEURONS - self.num_outputs then
+       if neuron1 >= CONF.MAX_NEURONS - self.num_outputs then
                return
        end
 
@@ -348,7 +334,7 @@ function Genome:get_random_neuron(can_be_input)
        end
 
        for o = 1, self.num_outputs do
-               neurons[MAX_NEURONS - o] = true
+               neurons[CONF.MAX_NEURONS - o] = true
        end
 
        for i = 1, #genes do
@@ -426,7 +412,9 @@ local Population_mt = { __index = Population }
 function Population.new()
        local o = {
                genomes = {};
+               genome_count = 0;
                generation = 0;
+               max_innovations = 0;
                current_genome = 0;
                high_fitness = 0;
                total_fitness = 0;
@@ -439,6 +427,7 @@ end
 
 function Population:create_genomes(num, inputs, outputs)
        local genomes = self.genomes
+       self.genome_count = num
 
        for i = 1, num do
                genomes[i] = Genome.new(inputs, outputs)
@@ -446,16 +435,18 @@ function Population:create_genomes(num, inputs, outputs)
        end
 end
 
-function Population:breed_genome()
+function Population:breed_genome(max_genome)
        local genomes = self.genomes
+       max_genome = max_genome or #genomes
+
        local child
 
-       if math.random() < Crossover_Chance then
-               local g1 = genomes[math.random(1, #genomes)]
-               local g2 = genomes[math.random(1, #genomes)]
+       if math.random() < CONF.Crossover_Chance then
+               local g1 = genomes[math.random(1, max_genome)]
+               local g2 = genomes[math.random(1, max_genome)]
                child = g1:crossover(g2)
        else
-               local g = genomes[math.random(1, #genomes)]
+               local g = genomes[math.random(1, max_genome)]
                child = g:copy()
        end
 
@@ -470,7 +461,7 @@ function Population:kill_worst()
                return a.fitness > b.fitness
        end)
 
-       local count = math.floor(2 * #self.genomes / 3)
+       local count = math.floor(#self.genomes * (1 - CONF.GENOME_THRESHOLD))
        for _ = 1, count do
                table.remove(self.genomes) -- This removes the last (worst) genome
        end
@@ -483,81 +474,79 @@ function Population:kill_worst()
 end
 
 function Population:mate()
-       local count = #self.genomes * 2
+       local start_count = #self.genomes
+       local count = self.genome_count - #self.genomes
 
-       -- Double the population size
        for _ = 1, count do
-               table.insert(self.genomes, self:breed_genome())
+               table.insert(self.genomes, self:breed_genome(start_count))
        end
 
-
        self.generation = self.generation + 1
 end
 
-function Population:evolve()
-       local evolve_test, finish_evolve
-
-       -- First we need to calculate the fitnesses of every genome
-       self.current_genome = 0
-       function evolve_test(inputs, output_func, _, ...)
-               if self.current_genome == 0 then
-                       self.current_genome = 1
-                       self.genomes[self.current_genome]:create_network()
-               end
+function Population:training_step(inputs, output_func, _, ...)
+       if self.current_genome == 0 then
+               self.current_genome = 1
+               self.genomes[self.current_genome]:create_network()
+       end
 
-               if self.current_genome <= #self.genomes then
-                       -- Assumes genome has network generated
-                       local genome = self.genomes[self.current_genome]
-                       inputs[#inputs + 1] = 1 -- Bias neuron
+       if self.current_genome <= #self.genomes then
+               -- Assumes genome has network generated
+               local genome = self.genomes[self.current_genome]
+               inputs[#inputs + 1] = 1 -- Bias neuron
 
-                       genome.network:activate(inputs)
+               genome.network:activate(inputs)
 
-                       local outputs = genome.network:get_outputs()
-                       local fitness_change, cont = output_func(outputs, ...)
+               local outputs = genome.network:get_outputs()
+               local fitness_change, cont = output_func(outputs, ...)
 
-                       genome.fitness = genome.fitness + fitness_change
+               genome.fitness = genome.fitness + fitness_change
 
-                       if cont then
-                               return evolve_test
-                       else
-                               if genome.fitness > self.high_fitness then
-                                       self.high_fitness = genome.fitness
-                               end
+               if cont then
+                       return self.training_step
+               else
+                       if genome.fitness > self.high_fitness then
+                               self.high_fitness = genome.fitness
+                       end
 
-                               self.total_fitness = self.total_fitness + genome.fitness
-                               self.avg_fitness = self.total_fitness / self.current_genome
+                       self.total_fitness = self.total_fitness + genome.fitness
+                       self.avg_fitness = self.total_fitness / self.current_genome
 
-                               self.current_genome = self.current_genome + 1
+                       self.current_genome = self.current_genome + 1
 
-                               if self.current_genome <= #self.genomes then
-                                       self.genomes[self.current_genome]:create_network()
-                                       return evolve_test
-                               else
-                                       return finish_evolve
-                               end
+                       if self.current_genome <= #self.genomes then
+                               self.genomes[self.current_genome]:create_network()
+                               return self.training_step
+                       else
+                               return self.evolve
                        end
-               else
-                       return finish_evolve
                end
+       else
+               return self.evolve
        end
+end
+
+function Population:evolve(_, _, generation_step, ...)
+       generation_step(self.avg_fitness, self.high_fitness, ...)
+       self:kill_worst()
+       self:mate()
+
+       self.current_genome = 0
+       self.high_fitness = 0
+       self.avg_fitness = 0
+       self.total_fitness = 0
 
+       return self.training_step
+end
+
+function Population:start_training()
+       -- First we need to calculate the fitnesses of every genome
        -- Then we need to kill off the worst of them
        -- Then we breed more
        -- Rinse and repeat!
-       function finish_evolve(_, _, generation_step, ...)
-               generation_step(self.avg_fitness, self.high_fitness, ...)
-
-               self:kill_worst()
-               self:mate()
 
-               self.current_genome = 0
-               self.high_fitness = 0
-               self.avg_fitness = 0
-               self.total_fitness = 0
-               return evolve_test
-       end
-
-       return evolve_test
+       self.current_genome = 0
+       return self.training_step
 end
 
 return {
index baafe4a1a59205890a3fd80b021fee9d4fce2184..a91d23cca0fab7ef65970fae6bbf1c3fa02784e1 100644 (file)
@@ -22,7 +22,6 @@ function Input.new()
        return o
 end
 
--- Ugly way of righting it but I don't care (right now at least... :P)
 function Input:keydown(key)
        if     key == KEYS.MOVE_UP    then self.move_up    = true
        elseif key == KEYS.MOVE_DOWN  then self.move_down  = true
index fd259bf27492c2691fcb7ee3d57774d69bfa4275..da4ca34bd4be14a1bafe37db3b74b0f44f47b60a 100644 (file)
@@ -1,5 +1,4 @@
-local conf = require "conf"
-local MAX_NEURONS = conf.MAX_NEURONS
+local CONF = require "conf"
 
 -- Simple neural network implementation (perceptron)
 
@@ -37,7 +36,7 @@ function NeuralNetwork.new(num_inputs, num_outputs)
 
        -- num_inputs + 1 to num_inputs + num_outputs are output nodes
        for i = 1, num_outputs do
-               o.neurons[MAX_NEURONS - i] = Neuron.new(600, (i - 1) * 32)
+               o.neurons[CONF.MAX_NEURONS - i] = Neuron.new(600, (i - 1) * 32)
        end
 
        setmetatable(o, NeuralNetwork_mt)
@@ -48,7 +47,6 @@ function NeuralNetwork:add_connection(from, to, weight, id)
        local neurons = self.neurons
 
        if type(from) == "table" then
-               assert(from.to ~= from.from, "NEURON GOING TO ITSELF")
                table.insert(neurons[from.to].inputs, from)
        else
                table.insert(neurons[to].inputs, {
@@ -61,7 +59,7 @@ function NeuralNetwork:add_connection(from, to, weight, id)
 end
 
 function NeuralNetwork:add_neuron()
-       self.neurons[self.next_neuron] = Neuron.new(math.random(500) + 100, math.random(400) + 50)
+       self.neurons[self.next_neuron] = Neuron.new(math.random(400) + 100, math.random(400) + 50)
        self.next_neuron = self.next_neuron + 1
        return self.next_neuron - 1
 end
@@ -92,7 +90,6 @@ function NeuralNetwork:activate(inputs)
                end
        end
 
-       -- Iterate backwards since the hidden nodes are going to be at the end of the array
        for i, _ in pairs(ns) do
                if ns[i].dirty then
                        self:activate_neuron(i)
@@ -128,7 +125,7 @@ function NeuralNetwork:get_outputs()
        local ret = {}
 
        for i = 1, self.num_outputs do
-               ret[i] = self.neurons[MAX_NEURONS - i].value
+               ret[i] = self.neurons[CONF.MAX_NEURONS - i].value
        end
 
        return ret
diff --git a/src/trainer.lua b/src/trainer.lua
new file mode 100644 (file)
index 0000000..b6f4c92
--- /dev/null
@@ -0,0 +1,128 @@
+local CONF = require "conf"
+
+local Trainer = {}
+local Trainer_mt = { __index = Trainer }
+
+function Trainer.new(population, world, input)
+       local o = {
+               world = world;
+               player = world.player;
+               input = input;
+               population = population;
+
+               population_step = nil;
+               after_inputs_func = nil;
+               generation_step_func = nil;
+
+               speed = 1;
+               max_speed = 60;
+       }
+
+       setmetatable(o, Trainer_mt)
+       return o
+end
+
+function Trainer:initialize_training()
+       self.population_step = self.population:start_training()
+
+       self.after_inputs_func = function(...)
+               return self:after_inputs(...)
+       end
+
+       self.generation_step_func = function(...)
+               return self:generation_step(...)
+       end
+end
+
+function Trainer:get_inputs()
+       local dists = self.player:get_distances(self.world)
+
+       local inputs = {}
+       for i = 1, 16 do
+               local v1 = dists[i * 2]
+               local v2 = dists[(i * 2 + 1) % 32]
+               local v3 = dists[(i * 2 - 1) % 32]
+
+               inputs[i] = 1 - ((0.5 * v1 + 0.25 * v2 + 0.25 * v3) / (CONF.ENEMY_SIZE * CONF.PLAYER_VISION_DISTANCE))
+       end
+
+       return inputs
+end
+
+function Trainer:after_inputs(inputs, dt)
+       -- Make sure the player is considered alive at the start of every turn
+       self.player.alive = true
+
+       if inputs[1] > 0.35 then self.input:keydown("w")     else self.input:keyup("w") end
+       if inputs[2] > 0.35 then self.input:keydown("s")     else self.input:keyup("s") end
+       if inputs[3] > 0.35 then self.input:keydown("a")     else self.input:keyup("a") end
+       if inputs[4] > 0.35 then self.input:keydown("d")     else self.input:keyup("d") end
+       if inputs[5] > 0.35 then self.input:keydown("up")    else self.input:keyup("up") end
+       if inputs[6] > 0.35 then self.input:keydown("down")  else self.input:keyup("down") end
+       if inputs[7] > 0.35 then self.input:keydown("left")  else self.input:keyup("left") end
+       if inputs[8] > 0.35 then self.input:keydown("right") else self.input:keyup("right") end
+
+       local last_x     = self.player.x
+       local last_y     = self.player.y
+       local last_kills = self.player.kills
+
+       self.world:update(dt, self.input)
+
+       local fitness = math.sqrt(math.sqrDist(last_x, last_y, self.player.x, self.player.y))
+
+       fitness = fitness - (self.player.shot and 1 or 0)
+       self.player.shot = false
+
+       if self.player.kills ~= last_kills then
+               fitness = fitness + 400 * (self.player.kills - last_kills)
+       end
+
+       if not self.player.alive or self.world:get_count{ "Enemy" } == 0 then
+               self.world:kill_all{ "Bullet", "Enemy" }
+
+               if self.player.alive then
+                       fitness = fitness + 2000
+                       self.world:next_round()
+               else
+                       self.world:reset()
+               end
+
+               self.world:spawn_enemies(self.world.round)
+       end
+
+       return fitness, self.player.alive
+end
+
+function Trainer:generation_step(avg, high, _)
+       print "PROCEEDING TO NEXT GENERATION"
+end
+
+function Trainer:update(dt)
+       local inputs = self:get_inputs()
+
+       for _ = 1, self.speed do
+               self.population_step = self.population_step(
+                       self.population,
+                       inputs,
+                       self.after_inputs_func,
+                       self.generation_step_func,
+                       dt
+               )
+       end
+end
+
+function Trainer:change_speed(delta)
+       self.speed = self.speed + delta
+
+       if self.speed < 1 then
+               self.speed = 1
+       end
+
+       if self.speed > self.max_speed then
+               self.speed = self.max_speed
+       end
+end
+
+return {
+       Trainer = Trainer;
+}
index f3ca6076dc780a5a7fc4e9f216c3f57cd661d615..af8ed63b104e30bf115bed0b12290bb10b5564b7 100644 (file)
@@ -26,6 +26,14 @@ function math.rectintersects(r1, r2)
        return r1[1] <= r2[1] + r2[3] and r1[2] <= r2[2] + r2[4] and r1[1] + r1[3] >= r2[1] and r1[2] + r1[4] >= r2[2]
 end
 
+function table.contains(t, v)
+       for _, a in pairs(t) do
+               if a == v then return true end
+       end
+
+       return false
+end
+
 local function ripairsiter(t, i)
        i = i - 1
        if i ~= 0 then
index d0b049c71970dd5e0812bc0e90ab632f3887f260..15ea87499b4768477cc9dd585fb352414215526c 100644 (file)
@@ -31,11 +31,14 @@ function Bullet:update(dt, world)
        end
 end
 
-function Bullet:collide(other, dx, dy, world)
+function Bullet:collide(other, _, _, world)
        if other.ENTITY_TYPE == "Enemy" then
                other.alive = false
                world:remove_entity(other)
                world:remove_entity(self)
+
+               -- Reward the player a kill
+               world.player.kills = world.player.kills + 1
        end
 end
 
@@ -80,14 +83,15 @@ Player.ENTITY_TYPE = "Player"
 
 function Player.new()
        local o = {
-               x = CONF.WINDOW_WIDTH / 2;
-               y = CONF.WINDOW_HEIGHT / 2;
+               x = 400;
+               y = 300;
                r = 20;
                alive = true;
                fire_cooldown = 0;
 
                distances = {};
                shot = false;
+               kills = 0;
        }
 
        setmetatable(o, Player_mt)
@@ -203,7 +207,7 @@ function Player:draw()
        love.graphics.setColor(CONF.PLAYER_COLOR)
        love.graphics.circle("fill", self.x, self.y, self.r)
 
-       love.graphics.setColor(0, 0, 0)
+       love.graphics.setColor(CONF.PLAYER_VISION_COLOR)
        for i = 0, CONF.PLAYER_VISION_SEGMENTS - 1 do
                local a = i * 2 * math.pi / CONF.PLAYER_VISION_SEGMENTS
                local dx = math.cos(a)
@@ -217,7 +221,7 @@ function Player:draw()
 
                if self.distances[i + 1] > 0 then
                        local d = self.distances[i + 1]
-                       love.graphics.circle("fill", self.x + dx * d, self.y + dy * d, 5)
+                       love.graphics.circle("fill", self.x + dx * d, self.y + dy * d, 8)
                end
        end
 end
@@ -295,6 +299,8 @@ function Wall:update(dt)
 end
 
 function Wall:draw()
+       love.graphics.setColor(0, 0, 0)
+       love.graphics.rectangle("fill", unpack(self:get_rect()))
 end
 
 function Wall:get_rect()
@@ -325,6 +331,7 @@ function World.new(player)
                entities = {};
 
                player = player;
+               round = 1;
        }
 
        setmetatable(o, World_mt)
@@ -339,6 +346,11 @@ function World:update(dt, input)
        end
 
        self.player:update(dt, self, input)
+
+       -- if self:get_count{ "Enemy" } == 0 and self.player.alive then
+       --      self:next_round()
+       --      self:spawn_enemies(self.round)
+       -- end
 end
 
 function World:add_entity(ent)
@@ -394,6 +406,59 @@ function World:move_entity(ent, dx, dy)
        end
 end
 
+function World:get_count(types)
+       local cnt = 0
+       for _, v in ipairs(self.entities) do
+               if table.contains(types, v.ENTITY_TYPE) then
+                       cnt = cnt + 1
+               end
+       end
+
+       return cnt
+end
+
+function World:kill_all(types)
+       local i = 0
+
+       -- Because we are deleting from the list as we go, we have to
+       -- do this iteratively
+       repeat
+               i = i + 1
+
+               if self.entities[i] ~= nil then
+                       if table.contains(types, self.entities[i].ENTITY_TYPE) then
+                               self:remove_entity(self.entities[i])
+                               i = i - 1
+                       end
+               end
+       until i == #self.entities
+end
+
+function World:spawn_enemies(count)
+       for _ = 1, count do
+               local vert = math.random(2) > 1
+               local tmp = math.random(2) > 1
+               local x = math.random(vert and 100 or 800) + (vert and (tmp and 600 or 0) or 0)
+
+               vert = not vert
+               tmp = math.random(2) > 1
+               local y = math.random(vert and 100 or 600) + (vert and (tmp and 400 or 0) or 0)
+
+               local enemy = Enemy.new(x, y)
+               self:add_entity(enemy)
+       end
+end
+
+function World:next_round()
+       self.round = self.round + 1
+end
+
+function World:reset()
+       self.round = 1
+       self.player.x = 400
+       self.player.y = 300
+end
+
 function World:draw()
        for _, e in ipairs(self.entities) do
                e:draw()