From: Brendan Hansen Date: Fri, 15 Mar 2019 18:55:40 +0000 (-0500) Subject: Added statistics X-Git-Url: https://git.brendanfh.com/?a=commitdiff_plain;h=4b0c316c50ba941307a5eddfcb542412947b4c25;p=genetic-shooter.git Added statistics --- diff --git a/conf.lua b/conf.lua index 3cd3078..d0e6318 100644 --- a/conf.lua +++ b/conf.lua @@ -44,8 +44,8 @@ local KEYMAP = { return { -- GENERAL PROPERTIES - LOAD_FILE = "./saved/TERM_1_GEN_14"; - SAVE_FILE = "./saved/TERM"; + LOAD_FILE = ""; + SAVE_FILE = "./saved/DELETE"; WINDOW_WIDTH = WINDOW_WIDTH; WINDOW_HEIGHT = WINDOW_HEIGHT; diff --git a/main.lua b/main.lua index cea19bb..ebd675a 100644 --- a/main.lua +++ b/main.lua @@ -5,6 +5,7 @@ local Input = require "src.input" local Gen = require "src.genetics" require "src.data" local Trainer = (require "src.trainer").Trainer +local Stats = (require "src.stats").Stats local World = world_mod.World local Wall = world_mod.Wall @@ -14,12 +15,12 @@ local world local input local pop local trainer +local pop_stats +local gen_stats local ui_font local fitness_font -local stored_fitnesses = {} - function love.load() math.randomseed(os.time()) @@ -43,6 +44,9 @@ function love.load() pop = Population.load(CONF.LOAD_FILE) end + pop_stats = Stats.new() + gen_stats = Stats.new() + trainer = Trainer.new(pop, world, input) trainer:initialize_training() @@ -72,7 +76,7 @@ function love.update(dt) trainer:change_speed(1) end - trainer:update(dt) + trainer:update(dt, pop_stats, gen_stats) --world:update(dt, input) end @@ -90,21 +94,34 @@ local function plot_fitness(x, y, scale) love.graphics.printf("Average fitness: " .. math.floor(pop.avg_fitness), 0, 0, 640, "left") love.graphics.printf("Highest fitness: " .. math.floor(pop.high_fitness), 0, 32, 640, "left") - local highest = 0 - for _, v in ipairs(stored_fitnesses) do - if v > highest then - highest = v - end + local points = pop_stats:get_points(0, 120, 640, 200) + + love.graphics.setColor(0, 0, 1) + for _, v in ipairs(points) do + love.graphics.circle("fill", v[1], v[2], 8) end - local width = 640 / (#stored_fitnesses) + love.graphics.pop() +end - love.graphics.setColor(0, 0, 1) - for i, v in ipairs(stored_fitnesses) do - if v < 0 then - v = 0 - end - love.graphics.circle("fill", (i - 1) * width, 300 - v * 200 / highest, 8) +local function plot_generation(x, y, scale) + love.graphics.push() + love.graphics.translate(x, y) + love.graphics.scale(scale, scale) + + love.graphics.setColor(0, 0, 0, 0.4) + love.graphics.rectangle("fill", -20, -20, 680, 340) + + love.graphics.setFont(fitness_font) + love.graphics.setColor(CONF.FONT_COLOR) + + love.graphics.printf("Fitness over Genome", 0, 0, 640, "left") + + local points = gen_stats:get_points(0, 60, 640, 260) + + love.graphics.setColor(1, 0, 0) + for _, v in ipairs(points) do + love.graphics.circle("fill", v[1], v[2], 8) end love.graphics.pop() @@ -182,5 +199,6 @@ function love.draw() draw_network(pop.genomes[pop.current_genome].network, 1200 - 350, 32, 1 / 2) end - --plot_fitness(1200 - 350, 352, 1 / 2) + plot_fitness(1200 - 350, 352, 1 / 2) + plot_generation(1200 - 350, 600, 1 / 2) end diff --git a/src/stats.lua b/src/stats.lua new file mode 100644 index 0000000..907b34c --- /dev/null +++ b/src/stats.lua @@ -0,0 +1,96 @@ + +local Stats = {} +local Stats_mt = { __index = Stats } + +function Stats.new(data) + data = data or {} + + local o = { + data = data; + + max = 0; + min = 0; + rng = 0; + avg = 0; + stddev = 0; + } + + setmetatable(o, Stats_mt) + return o +end + +function Stats:add_point(d) + table.insert(self.data, d) +end + +function Stats:clear() + self.data = {} +end + +function Stats:calculate() + if #self.data == 0 then + self.max = 0 + self.min = 0 + self.avg = 0 + self.stddev = 0 + self.rng = 0 + return + end + + local sum = 0 + self.max = nil + self.min = nil + for _, v in ipairs(self.data) do + if self.max == nil then + self.max = v + else + if v > self.max then + self.max = v + end + end + + if self.min == nil then + self.min = v + else + if v < self.min then + self.min = v + end + end + + sum = sum + v + end + + self.avg = sum / #self.data + self.rng = self.max - self.min + + local diff_sum = 0 + for _, v in ipairs(self.data) do + diff_sum = diff_sum + (self.avg - v) ^ 2 + end + self.stddev = diff_sum / #self.data +end + +function Stats:get_points(x, y, w, h) + self:calculate() + + local low = y + h + + local delta + if #self.data == 1 then + delta = 0 + else + delta = w / (#self.data - 1) + end + + local points = {} + + for i, v in ipairs(self.data) do + table.insert(points, { (i - 1) * delta + x, low - h * (v / self.rng) }) + end + + return points +end + +return { + Stats = Stats; +} diff --git a/src/trainer.lua b/src/trainer.lua index 2c8b14d..afbe841 100644 --- a/src/trainer.lua +++ b/src/trainer.lua @@ -55,7 +55,7 @@ function Trainer:get_inputs() return inputs end -function Trainer:after_inputs(inputs, dt) +function Trainer:after_inputs(inputs, dt, _, gen_stats) -- Make sure the player is considered alive at the start of every turn self.player.alive = true @@ -93,6 +93,7 @@ function Trainer:after_inputs(inputs, dt) fitness = fitness + CONF.POINTS_PER_ROUND_END self.world:next_round() else + gen_stats:add_point(self.population.genomes[self.population.current_genome].fitness) self.world:reset() end @@ -102,18 +103,21 @@ function Trainer:after_inputs(inputs, dt) return fitness, self.player.alive end -function Trainer:pre_evolution(avg, high, _) +function Trainer:pre_evolution(avg, high, _, stats, _) + stats:add_point(avg) + self.population:save(CONF.SAVE_FILE) + print("FINISHED GENERATION: " .. self.population.generation .. " | Stats: ") print(" Average: " .. tostring(avg)) print(" Highest: " .. tostring(high)) print("--------------------------------------------------------------------") end -function Trainer:post_evolution(_) - self.population:save(CONF.SAVE_FILE) +function Trainer:post_evolution(_, _, gen_stats) + gen_stats:clear() end -function Trainer:update(dt) +function Trainer:update(...) local inputs = self:get_inputs() for _ = 1, self.speed do @@ -122,7 +126,7 @@ function Trainer:update(dt) inputs, self.after_inputs_func, { self.pre_evolution_func, self.post_evolution_func }, - dt + ... ) end end