Skip to content

Commit

Permalink
shainet training
Browse files Browse the repository at this point in the history
  • Loading branch information
fdocr committed Oct 21, 2023
1 parent d9fc013 commit 4b3378d
Show file tree
Hide file tree
Showing 11 changed files with 1,111 additions and 119 deletions.
9 changes: 8 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -6,5 +6,12 @@
app
worker
bundle
mass_nn_training
script
train_network
.env
.env.test
.env.test
*.csv
/nn_training/*.log
/nn_training/*.nn
/nn_training/old/
Empty file added nn_training/.keep
Empty file.
3 changes: 2 additions & 1 deletion spec/server_spec.cr
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@ SUPPORTED_STRATEGIES = [
"blast_random_valid",
"chase_closest_food",
"chase_random_food",
"cautious_carol"
"cautious_carol",
"cc_nn"
]

describe "Crystal Snake Battlesnake endpoints for all supported strategies" do
Expand Down
27 changes: 27 additions & 0 deletions src/battle_snake/context.cr
Original file line number Diff line number Diff line change
Expand Up @@ -248,4 +248,31 @@ class BattleSnake::Context
board.snakes.delete(snake)
end
end

def to_nn_input
turn_data = Array.new(121, 0.as(Int32 | String))
board.food.each do |food|
offset = food.x + (food.y * 11)
turn_data[offset] = 10
end

snake_counter = 100
you.body.each do |point|
offset = point.x + (point.y * 11)
turn_data[offset] = snake_counter
snake_counter += 1
end

enemies.each_with_index do |snake, i|
snake_counter = 200 + (i * 100)
snake.body.each do |point|
offset = point.x + (point.y * 11)
turn_data[offset] = snake_counter
snake_counter += 1
end
snake.body
end

turn_data.map(&.to_i32)
end
end
1 change: 1 addition & 0 deletions src/cc_nn.nn

Large diffs are not rendered by default.

45 changes: 45 additions & 0 deletions src/mass_nn_training.cr
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
require "dotenv"
Dotenv.load if File.exists?(".env")
require "../config/**"

# Shainet
layer_neurons_opts = [35, 60, 75]
learning_rate_opts = [0.005, 0.2, 0.7]
momentum_opts = [0.05, 0.1, 0.3]
error_threshold = 0.0001
option_permutations = [] of Hash(Symbol, Float64)
27.times do |i|
hash = Hash(Symbol, Float64).new
hash[:neurons] = layer_neurons_opts[i % 3]
hash[:learning_rate] = learning_rate_opts[(i / 3).to_i % 3]
hash[:momentum] = momentum_opts[(i / 9).to_i % 3]
option_permutations << hash
end

Log.info { "#{option_permutations.size} options" }
options_channel = Channel(Hash(Symbol, Float64)).new
complete_channel = Channel(Int32).new

5.times do |i|
spawn do
loop do
[200, 1000, 15000].shuffle.each do |epochs|
options = options_channel.receive
log_name = "./nn_training/cc_nn_#{options[:neurons].to_i}n_#{options[:learning_rate]}l_#{options[:momentum]}m_#{epochs}e_#{error_threshold}t.log"
command = "./train_network -n #{options[:neurons].to_i} -l #{options[:learning_rate]} -m #{options[:momentum]} -e #{epochs} -t #{error_threshold} | tee #{log_name}"
system command
end

complete_channel.send(i)
end
end
end

# Delivery permutations
spawn { option_permutations.shuffle.each { |options| options_channel.send(options) } }

# Wait for permutations processing
27.times do |i|
res = complete_channel.receive
Log.info { "[root] -> Finished processing from #{res}" }
end
4 changes: 2 additions & 2 deletions src/sam.cr
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ load_dependencies "jennifer"
task "dev" do
sentry = Sentry::ProcessRunner.new(
display_name: "App",
build_command: "crystal build ./src/app.cr",
build_command: "crystal build ./src/app.cr -Dpreview_mt",
run_command: "./app",
run_args: ["-p", "8080"],
files: [ "./src/**/*", "./config/*.cr" ]
Expand All @@ -25,7 +25,7 @@ task "test" do
end

task "script" do
system "crystal run ./src/script.cr"
system "crystal run ./src/script.cr -Dpreview_mt"
end

Sam.help
Loading

0 comments on commit 4b3378d

Please sign in to comment.