import tensorflow as tf
from tensorflow.keras.optimizers import Adam
from tf_agents.agents.dqn.dqn_agent import DqnAgent
from tf_agents.drivers import dynamic_step_driver
from tf_agents.environments.tf_py_environment import TFPyEnvironment
from tf_agents.networks.q_network import QNetwork
from tf_agents.policies import random_tf_policy
from tf_agents.replay_buffers import tf_uniform_replay_buffer
from tf_agents.utils import common

LOG_PERIOD = 1000       # how often the reward graph is updated
PRINT_PERIOD = 100      # how often the training loss is printed
LEARNING_RATE = 0.001
NUM_ITERATIONS = 100_000

# Graphic (reward plotting) and RandomTicTacToeEnvironment are custom classes,
# not part of TF-Agents.
graph = Graphic(LOG_PERIOD)
tf_env = TFPyEnvironment(RandomTicTacToeEnvironment())
# Q-network mapping board observations to one Q-value per action,
# with a single fully connected hidden layer of 100 units.
q_net = QNetwork(
    tf_env.observation_spec(),
    tf_env.action_spec(),
    fc_layer_params=(100,)
)
train_step_counter = tf.Variable(0)

# DQN agent whose collect policy explores with 10% random moves.
agent = DqnAgent(
    time_step_spec=tf_env.time_step_spec(),
    action_spec=tf_env.action_spec(),
    q_network=q_net,
    optimizer=Adam(learning_rate=LEARNING_RATE),
    td_errors_loss_fn=common.element_wise_squared_loss,
    epsilon_greedy=0.1,
    train_step_counter=train_step_counter
)
agent.initialize()
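# agent.policy is the greedy policy used for evaluation;
# agent.collect_policy adds epsilon-greedy exploration during data collection.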
eval_policy = agent.policy
collect_policy = agent.collect_policy
# Uniform replay buffer that keeps up to 1000 trajectories per environment batch.
replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
    data_spec=agent.collect_data_spec,
    batch_size=tf_env.batch_size,
    max_length=1000
)
# Driver that plays 10 environment steps per call with the collect policy
# and writes the resulting trajectories into the replay buffer.
collect_driver = dynamic_step_driver.DynamicStepDriver(
    tf_env,
    collect_policy,
    observers=[replay_buffer.add_batch],
    num_steps=10
)
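# Compile the collection and training steps into TensorFlow graph functions for speed.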
collect_driver.run = common.function(collect_driver.run)
agent.train = common.function(agent.train)
# A purely random policy, used only to seed the replay buffer before training.
initial_collect_policy = random_tf_policy.RandomTFPolicy(
    tf_env.time_step_spec(),
    tf_env.action_spec()
)

# Dataset that samples mini-batches of 8 two-step trajectories
# (adjacent time-step pairs, as required for DQN's TD targets).
dataset = replay_buffer.as_dataset(
    num_parallel_calls=3,
    sample_batch_size=8,
    num_steps=2,
    single_deterministic_pass=False
).prefetch(3)
iterator = iter(dataset)

# Seed the replay buffer with 10 steps from the random policy so it is not
# empty when training starts.
dynamic_step_driver.DynamicStepDriver(
    tf_env,
    initial_collect_policy,
    observers=[replay_buffer.add_batch],
    num_steps=10
).run()
# Main training loop: alternate data collection and Q-network updates.
time_step = tf_env.reset()
for _ in range(NUM_ITERATIONS + 1):
    # Play 10 steps with the epsilon-greedy policy and store them in the buffer.
    time_step, _ = collect_driver.run(time_step)
    # Sample a mini-batch of experience and take one gradient step.
    experience, _ = next(iterator)
    step = agent.train_step_counter.numpy()
    train_loss = agent.train(experience).loss
    if step % PRINT_PERIOD == 0:
        print('step = {0}: loss = {1}'.format(step, train_loss))
    # Feed every reward in the sampled batch to the plotting helper.
    for reward in tf.reshape(experience.reward, [-1]):
        graph.check(step, reward)