local Gen = require "src.genetics"
require "src.data"
local Trainer = (require "src.trainer").Trainer
+local Stats = (require "src.stats").Stats
local World = world_mod.World
local Wall = world_mod.Wall
local input
local pop
local trainer
+local pop_stats
+local gen_stats
local ui_font
local fitness_font
-local stored_fitnesses = {}
-
function love.load()
math.randomseed(os.time())
pop = Population.load(CONF.LOAD_FILE)
end
+ pop_stats = Stats.new()
+ gen_stats = Stats.new()
+
trainer = Trainer.new(pop, world, input)
trainer:initialize_training()
trainer:change_speed(1)
end
- trainer:update(dt)
+ trainer:update(dt, pop_stats, gen_stats)
--world:update(dt, input)
end
love.graphics.printf("Average fitness: " .. math.floor(pop.avg_fitness), 0, 0, 640, "left")
love.graphics.printf("Highest fitness: " .. math.floor(pop.high_fitness), 0, 32, 640, "left")
- local highest = 0
- for _, v in ipairs(stored_fitnesses) do
- if v > highest then
- highest = v
- end
+ local points = pop_stats:get_points(0, 120, 640, 200)
+
+ love.graphics.setColor(0, 0, 1)
+ for _, v in ipairs(points) do
+ love.graphics.circle("fill", v[1], v[2], 8)
end
- local width = 640 / (#stored_fitnesses)
+ love.graphics.pop()
+end
- love.graphics.setColor(0, 0, 1)
- for i, v in ipairs(stored_fitnesses) do
- if v < 0 then
- v = 0
- end
- love.graphics.circle("fill", (i - 1) * width, 300 - v * 200 / highest, 8)
+local function plot_generation(x, y, scale)
+ love.graphics.push()
+ love.graphics.translate(x, y)
+ love.graphics.scale(scale, scale)
+
+ love.graphics.setColor(0, 0, 0, 0.4)
+ love.graphics.rectangle("fill", -20, -20, 680, 340)
+
+ love.graphics.setFont(fitness_font)
+ love.graphics.setColor(CONF.FONT_COLOR)
+
+ love.graphics.printf("Fitness over Genome", 0, 0, 640, "left")
+
+ local points = gen_stats:get_points(0, 60, 640, 260)
+
+ love.graphics.setColor(1, 0, 0)
+ for _, v in ipairs(points) do
+ love.graphics.circle("fill", v[1], v[2], 8)
end
love.graphics.pop()
draw_network(pop.genomes[pop.current_genome].network, 1200 - 350, 32, 1 / 2)
end
- --plot_fitness(1200 - 350, 352, 1 / 2)
+ plot_fitness(1200 - 350, 352, 1 / 2)
+ plot_generation(1200 - 350, 600, 1 / 2)
end
--- /dev/null
+
+local Stats = {}
+local Stats_mt = { __index = Stats }
+
+function Stats.new(data)
+ data = data or {}
+
+ local o = {
+ data = data;
+
+ max = 0;
+ min = 0;
+ rng = 0;
+ avg = 0;
+ stddev = 0;
+ }
+
+ setmetatable(o, Stats_mt)
+ return o
+end
+
+function Stats:add_point(d)
+ table.insert(self.data, d)
+end
+
+function Stats:clear()
+ self.data = {}
+end
+
+function Stats:calculate()
+ if #self.data == 0 then
+ self.max = 0
+ self.min = 0
+ self.avg = 0
+ self.stddev = 0
+ self.rng = 0
+ return
+ end
+
+ local sum = 0
+ self.max = nil
+ self.min = nil
+ for _, v in ipairs(self.data) do
+ if self.max == nil then
+ self.max = v
+ else
+ if v > self.max then
+ self.max = v
+ end
+ end
+
+ if self.min == nil then
+ self.min = v
+ else
+ if v < self.min then
+ self.min = v
+ end
+ end
+
+ sum = sum + v
+ end
+
+ self.avg = sum / #self.data
+ self.rng = self.max - self.min
+
+ local diff_sum = 0
+ for _, v in ipairs(self.data) do
+ diff_sum = diff_sum + (self.avg - v) ^ 2
+ end
+ self.stddev = diff_sum / #self.data
+end
+
+function Stats:get_points(x, y, w, h)
+ self:calculate()
+
+ local low = y + h
+
+ local delta
+ if #self.data == 1 then
+ delta = 0
+ else
+ delta = w / (#self.data - 1)
+ end
+
+ local points = {}
+
+ for i, v in ipairs(self.data) do
+ table.insert(points, { (i - 1) * delta + x, low - h * (v / self.rng) })
+ end
+
+ return points
+end
+
+return {
+ Stats = Stats;
+}
return inputs
end
-function Trainer:after_inputs(inputs, dt)
+function Trainer:after_inputs(inputs, dt, _, gen_stats)
-- Make sure the player is considered alive at the start of every turn
self.player.alive = true
fitness = fitness + CONF.POINTS_PER_ROUND_END
self.world:next_round()
else
+ gen_stats:add_point(self.population.genomes[self.population.current_genome].fitness)
self.world:reset()
end
return fitness, self.player.alive
end
-function Trainer:pre_evolution(avg, high, _)
+function Trainer:pre_evolution(avg, high, _, stats, _)
+ stats:add_point(avg)
+ self.population:save(CONF.SAVE_FILE)
+
print("FINISHED GENERATION: " .. self.population.generation .. " | Stats: ")
print(" Average: " .. tostring(avg))
print(" Highest: " .. tostring(high))
print("--------------------------------------------------------------------")
end
-function Trainer:post_evolution(_)
- self.population:save(CONF.SAVE_FILE)
+function Trainer:post_evolution(_, _, gen_stats)
+ gen_stats:clear()
end
-function Trainer:update(dt)
+function Trainer:update(...)
local inputs = self:get_inputs()
for _ = 1, self.speed do
inputs,
self.after_inputs_func,
{ self.pre_evolution_func, self.post_evolution_func },
- dt
+ ...
)
end
end