Added statistics
authorBrendan Hansen <brendan.f.hansen@gmail.com>
Fri, 15 Mar 2019 18:55:40 +0000 (13:55 -0500)
committerBrendan Hansen <brendan.f.hansen@gmail.com>
Fri, 15 Mar 2019 18:55:40 +0000 (13:55 -0500)
conf.lua
main.lua
src/stats.lua [new file with mode: 0644]
src/trainer.lua

index 3cd30786466396fbcc7071bdf71c2419d42bf83a..d0e6318ddb9886559ab610425d4e064d3b2d8cfb 100644 (file)
--- a/conf.lua
+++ b/conf.lua
@@ -44,8 +44,8 @@ local KEYMAP = {
 
 return {
        -- GENERAL PROPERTIES
-       LOAD_FILE = "./saved/TERM_1_GEN_14";
-       SAVE_FILE = "./saved/TERM";
+       LOAD_FILE = "";
+       SAVE_FILE = "./saved/DELETE";
 
        WINDOW_WIDTH = WINDOW_WIDTH;
        WINDOW_HEIGHT = WINDOW_HEIGHT;
index cea19bbc514d3ad43f7a605566bd962ec91bea51..ebd675aac516ad31878e18261b204bfbc76865ef 100644 (file)
--- a/main.lua
+++ b/main.lua
@@ -5,6 +5,7 @@ local Input = require "src.input"
 local Gen = require "src.genetics"
 require "src.data"
 local Trainer = (require "src.trainer").Trainer
+local Stats = (require "src.stats").Stats
 
 local World = world_mod.World
 local Wall = world_mod.Wall
@@ -14,12 +15,12 @@ local world
 local input
 local pop
 local trainer
+local pop_stats
+local gen_stats
 
 local ui_font
 local fitness_font
 
-local stored_fitnesses = {}
-
 function love.load()
        math.randomseed(os.time())
 
@@ -43,6 +44,9 @@ function love.load()
                pop = Population.load(CONF.LOAD_FILE)
        end
 
+       pop_stats = Stats.new()
+       gen_stats = Stats.new()
+
        trainer = Trainer.new(pop, world, input)
        trainer:initialize_training()
 
@@ -72,7 +76,7 @@ function love.update(dt)
                trainer:change_speed(1)
        end
 
-       trainer:update(dt)
+       trainer:update(dt, pop_stats, gen_stats)
        --world:update(dt, input)
 end
 
@@ -90,21 +94,34 @@ local function plot_fitness(x, y, scale)
        love.graphics.printf("Average fitness: " .. math.floor(pop.avg_fitness), 0, 0, 640, "left")
        love.graphics.printf("Highest fitness: " .. math.floor(pop.high_fitness), 0, 32, 640, "left")
 
-       local highest = 0
-       for _, v in ipairs(stored_fitnesses) do
-               if v > highest then
-                       highest = v
-               end
+       local points = pop_stats:get_points(0, 120, 640, 200)
+
+       love.graphics.setColor(0, 0, 1)
+       for _, v in ipairs(points) do
+               love.graphics.circle("fill", v[1], v[2], 8)
        end
 
-       local width = 640 / (#stored_fitnesses)
+       love.graphics.pop()
+end
 
-       love.graphics.setColor(0, 0, 1)
-       for i, v in ipairs(stored_fitnesses) do
-               if v < 0 then
-                       v = 0
-               end
-               love.graphics.circle("fill", (i - 1) * width, 300 - v * 200 / highest, 8)
+local function plot_generation(x, y, scale)
+       love.graphics.push()
+       love.graphics.translate(x, y)
+       love.graphics.scale(scale, scale)
+
+       love.graphics.setColor(0, 0, 0, 0.4)
+       love.graphics.rectangle("fill", -20, -20, 680, 340)
+
+       love.graphics.setFont(fitness_font)
+       love.graphics.setColor(CONF.FONT_COLOR)
+
+       love.graphics.printf("Fitness over Genome", 0, 0, 640, "left")
+
+       local points = gen_stats:get_points(0, 60, 640, 260)
+
+       love.graphics.setColor(1, 0, 0)
+       for _, v in ipairs(points) do
+               love.graphics.circle("fill", v[1], v[2], 8)
        end
 
        love.graphics.pop()
@@ -182,5 +199,6 @@ function love.draw()
                draw_network(pop.genomes[pop.current_genome].network, 1200 - 350, 32, 1 / 2)
        end
 
-       --plot_fitness(1200 - 350, 352, 1 / 2)
+       plot_fitness(1200 - 350, 352, 1 / 2)
+       plot_generation(1200 - 350, 600, 1 / 2)
 end
diff --git a/src/stats.lua b/src/stats.lua
new file mode 100644 (file)
index 0000000..907b34c
--- /dev/null
@@ -0,0 +1,96 @@
+
+local Stats = {}
+local Stats_mt = { __index = Stats }
+
+function Stats.new(data)
+       data = data or {}
+
+       local o = {
+               data = data;
+
+               max = 0;
+               min = 0;
+               rng = 0;
+               avg = 0;
+               stddev = 0;
+       }
+
+       setmetatable(o, Stats_mt)
+       return o
+end
+
+function Stats:add_point(d)
+       table.insert(self.data, d)
+end
+
+function Stats:clear()
+       self.data = {}
+end
+
+function Stats:calculate()
+       if #self.data == 0 then
+               self.max = 0
+               self.min = 0
+               self.avg = 0
+               self.stddev = 0
+               self.rng = 0
+               return
+       end
+
+       local sum = 0
+       self.max = nil
+       self.min = nil
+       for _, v in ipairs(self.data) do
+               if self.max == nil then
+                       self.max = v
+               else
+                       if v > self.max then
+                               self.max = v
+                       end
+               end
+
+               if self.min == nil then
+                       self.min = v
+               else
+                       if v < self.min then
+                               self.min = v
+                       end
+               end
+
+               sum = sum + v
+       end
+
+       self.avg = sum / #self.data
+       self.rng = self.max - self.min
+
+       local diff_sum = 0
+       for _, v in ipairs(self.data) do
+               diff_sum = diff_sum + (self.avg - v) ^ 2
+       end
+       self.stddev = diff_sum / #self.data
+end
+
+function Stats:get_points(x, y, w, h)
+       self:calculate()
+
+       local low = y + h
+
+       local delta
+       if #self.data == 1 then
+               delta = 0
+       else
+               delta = w / (#self.data - 1)
+       end
+
+       local points = {}
+
+       for i, v in ipairs(self.data) do
+               table.insert(points, { (i - 1) * delta + x, low - h * (v / self.rng) })
+       end
+
+       return points
+end
+
+return {
+       Stats = Stats;
+}
index 2c8b14df0fd12d721aa3101b1340eb24b4cfee5d..afbe84123203f994db6cfcd77a82f28e8a2002de 100644 (file)
@@ -55,7 +55,7 @@ function Trainer:get_inputs()
        return inputs
 end
 
-function Trainer:after_inputs(inputs, dt)
+function Trainer:after_inputs(inputs, dt, _, gen_stats)
        -- Make sure the player is considered alive at the start of every turn
        self.player.alive = true
 
@@ -93,6 +93,7 @@ function Trainer:after_inputs(inputs, dt)
                        fitness = fitness + CONF.POINTS_PER_ROUND_END
                        self.world:next_round()
                else
+                       gen_stats:add_point(self.population.genomes[self.population.current_genome].fitness)
                        self.world:reset()
                end
 
@@ -102,18 +103,21 @@ function Trainer:after_inputs(inputs, dt)
        return fitness, self.player.alive
 end
 
-function Trainer:pre_evolution(avg, high, _)
+function Trainer:pre_evolution(avg, high, _, stats, _)
+       stats:add_point(avg)
+       self.population:save(CONF.SAVE_FILE)
+
        print("FINISHED GENERATION: " .. self.population.generation .. " | Stats: ")
        print("        Average: " .. tostring(avg))
        print("        Highest: " .. tostring(high))
        print("--------------------------------------------------------------------")
 end
 
-function Trainer:post_evolution(_)
-       self.population:save(CONF.SAVE_FILE)
+function Trainer:post_evolution(_, _, gen_stats)
+       gen_stats:clear()
 end
 
-function Trainer:update(dt)
+function Trainer:update(...)
        local inputs = self:get_inputs()
 
        for _ = 1, self.speed do
@@ -122,7 +126,7 @@ function Trainer:update(dt)
                        inputs,
                        self.after_inputs_func,
                        { self.pre_evolution_func, self.post_evolution_func },
-                       dt
+                       ...
                )
        end
 end