Here we program an agent that learns to navigate a small grid map, using value iteration to find the best policy. That is, by repeatedly running simulations, the agent gradually learns which action gives the highest expected reward from each state.
Some core concepts of Markov Decision Processes (MDPs):
Let's say we have a state $s_i$, a reward $R_i$ for that state, and a policy $\pi$ (which outputs an action $\pi(s)$ for each state). Write $P_{s_i, a}(s_j)$ for the probability of ending up in state $s_j$ after taking action $a$ in state $s_i$, and let $\phi$ be a parameter called the discount factor. Then the expected reward when starting in state $s_i$ and following $\pi$ is

$$V(s_i) = R_i + \phi \sum_{j} P_{s_i, \pi(s_i)}(s_j) V(s_j).$$

Assuming we know $R_i$ and $P_{s_i, a}$ (for all actions $a$) perfectly, we can solve the resulting system of linear equations to determine $V(s_i)$ for all $i$.
Our goal is to find an optimal policy $\pi^*$. For this we use an algorithm called value iteration:

1. Initialise $V(s_i) = 0$ for all $i$.
2. Repeat until convergence: $V(s_i) := R_i + \max_{a} \phi \sum_{j} P_{s_i, a}(s_j) V(s_j)$.

This is guaranteed to converge to the optimal values, from which the optimal policy follows. If the state transition probabilities and/or rewards are unknown, we can repeatedly run simulations to estimate them and update our policy each time.
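Once the values have converged (whether computed from the true probabilities or from estimates), the policy itself is read off with an argmax over actions, which is what make_optimal_policy_map below computes ($R_i$ and $\phi$ are dropped since they don't change which action wins):

$$\pi^*(s_i) = \operatorname{argmax}_a \sum_{j} P_{s_i, a}(s_j) V(s_j).$$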
import itertools
import numpy as np
import collections
import random
import matplotlib.pyplot as plt
from PIL import Image
Define the "game": states, rewards, and possible actions. Goal is to get to the top right corner. There's also a bad position adjacent to the goal that we should avoid. This is based on an example in Andrew Ng's lectures, which in turn was taken from Norvig's AI book.
W, H = 4, 3
STATES = list(itertools.product(range(H), range(W)))
END_STATES = set([
(0, 3),
(1, 3)
])
print("States:")
print(" ", STATES)
REWARDS = {}
REWARDS[(0, 3)] = 1
REWARDS[(1, 3)] = -1
for s in STATES:
if s not in END_STATES:
# Slightly negative reward for non-goal states to
# discourage dilly-dallying.
REWARDS[s] = -.1
print("Reward map:")
print(" ", REWARDS)
class Action:
LEFT = "left"
RIGHT = "right"
UP = "up"
DOWN = "down"
ACTIONS = (Action.LEFT, Action.RIGHT, Action.UP, Action.DOWN)
print("Actions:")
print(" ", ACTIONS)
States:
   [(0, 0), (0, 1), (0, 2), (0, 3), (1, 0), (1, 1), (1, 2), (1, 3), (2, 0), (2, 1), (2, 2), (2, 3)]
Reward map:
   {(0, 3): 1, (1, 3): -1, (0, 0): -0.1, (0, 1): -0.1, (0, 2): -0.1, (1, 0): -0.1, (1, 1): -0.1, (1, 2): -0.1, (2, 0): -0.1, (2, 1): -0.1, (2, 2): -0.1, (2, 3): -0.1}
Actions:
   ('left', 'right', 'up', 'down')
And the game logic.
def transition(s, a):
next_states, cumulative_distr = get_possible_next_states(s, a)
return random.choices(next_states,
cum_weights=cumulative_distr)[0]
def get_possible_next_states(s, a, cumulative=True):
if a == Action.LEFT:
outcomes = [Action.UP, Action.DOWN, Action.LEFT]
elif a == Action.RIGHT:
outcomes = [Action.UP, Action.DOWN, Action.RIGHT]
elif a == Action.UP:
outcomes = [Action.LEFT, Action.RIGHT, Action.UP]
else:
outcomes = [Action.LEFT, Action.RIGHT, Action.DOWN]
# NOTE: near the walls, different outcome directions can clamp to the same
# next state, so this list may contain duplicates. That's fine for sampling,
# but bug-prone for any caller that treats the entries as distinct states.
return ([apply_action(s, outcome) for outcome in outcomes],
(.125, .25, 1.) if cumulative else (.125, .125, .75))
def apply_action(s, a):
i, j = s
if a == Action.LEFT: j -= 1
elif a == Action.RIGHT: j += 1
elif a == Action.UP: i -= 1
else: i += 1
return (max(0, min(i, H-1)), max(0, min(j, W-1)))
def get_reward(s):
return REWARDS[s] + np.random.normal(scale=.001)
def incorporate_stats(s_current, R_current, a, s_old, reward_estimate, state_count, transition_estimate):
reward_estimate[s_current] += R_current
state_count[s_current] += 1
pair = (s_old, a)
if pair not in transition_estimate:
transition_estimate[pair] = collections.defaultdict(int)
transition_estimate[pair][s_current] += 1
def simulate(policy_fn, s0=(2,2), max_steps=100, dropoff=.99):
s_current = s0
R = get_reward(s0)
steps = 0
reward_estimate = collections.defaultdict(float)
reward_estimate[s_current] = R
state_count = collections.defaultdict(int)
state_count[s0] = 1
transition_estimate = {}
curr_dropoff = dropoff
while s_current not in END_STATES and steps < max_steps:
a = policy_fn(s_current)
s_old = s_current
s_current = transition(s_current, a)
R_current = get_reward(s_current)
R += curr_dropoff*R_current
curr_dropoff *= dropoff
# Update estimates.
incorporate_stats(s_current, R_current, a, s_old, reward_estimate, state_count, transition_estimate)
steps += 1
end_states = set([s_current]) if s_current in END_STATES else set()
# Return not only the final reward, but estimates of the rewards and
# transition probabilities.
return R, reward_estimate, state_count, transition_estimate, end_states
Test it out with a policy that picks a random action each time (i.e. a stochastic policy rather than a fixed, deterministic one).
def random_policy(s):
return random.choice(ACTIONS)
R, reward_estimate, state_count, transition_estimate, end_states = simulate(random_policy)
print("Reward:", R)
print("Reward estimate (by state):")
for s in STATES:
print(" ", s, reward_estimate[s]/state_count[s]
if state_count[s] > 0
else "Never reached")
print("State counts:")
total_count = 0
for s, c in state_count.items():
total_count += c
print(" ", s, c)
print("Total count:", total_count)
print("Transition estimate:")
for pair in itertools.product(STATES, ACTIONS):
if pair in transition_estimate:
print(" ", pair, transition_estimate[pair])
print("End states:", end_states)
Reward: -0.15279179826506206
Reward estimate (by state):
   (0, 0) Never reached
   (0, 1) -0.09827322327816897
   (0, 2) -0.10052451242935982
   (0, 3) 1.0003209901458308
   (1, 0) Never reached
   (1, 1) -0.09950149171111555
   (1, 2) -0.10055114502555784
   (1, 3) Never reached
   (2, 0) Never reached
   (2, 1) -0.10063272594268527
   (2, 2) -0.10071028555084191
   (2, 3) Never reached
State counts:
   (2, 2) 2
   (1, 2) 3
   (1, 1) 2
   (2, 1) 1
   (0, 1) 1
   (0, 2) 2
   (0, 3) 1
   (0, 0) 0
   (1, 0) 0
   (1, 3) 0
   (2, 0) 0
   (2, 3) 0
Total count: 12
Transition estimate:
   ((0, 1), 'right') defaultdict(<class 'int'>, {(0, 2): 1})
   ((0, 2), 'left') defaultdict(<class 'int'>, {(1, 2): 1})
   ((0, 2), 'right') defaultdict(<class 'int'>, {(0, 3): 1})
   ((1, 1), 'left') defaultdict(<class 'int'>, {(2, 1): 1})
   ((1, 1), 'up') defaultdict(<class 'int'>, {(0, 1): 1})
   ((1, 2), 'up') defaultdict(<class 'int'>, {(0, 2): 1})
   ((1, 2), 'down') defaultdict(<class 'int'>, {(1, 1): 1, (2, 2): 1})
   ((2, 1), 'left') defaultdict(<class 'int'>, {(1, 1): 1})
   ((2, 2), 'left') defaultdict(<class 'int'>, {(1, 2): 1})
   ((2, 2), 'up') defaultdict(<class 'int'>, {(1, 2): 1})
End states: {(0, 3)}
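Incidentally, a single rollout only yields a handful of transition samples. As a rough sanity check of the transition model itself (just an illustration, nothing later depends on it), we can sample one state-action pair many times: the intended move should come up roughly 75% of the time, with about a 12.5% slip to each perpendicular direction.
# From the start state (2, 2), action "up" should mostly land on (1, 2),
# with occasional slips to (2, 1) and (2, 3).
n = 10000
slip_counts = collections.Counter(transition((2, 2), Action.UP) for _ in range(n))
print({s: c/n for s, c in slip_counts.items()})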
Visualising the map.
board = np.array([
[REWARDS[(i, j)] for j in range(W)]
for i in range(H)])
plt.matshow(board, cmap='hot')
plt.colorbar()
plt.xticks(ticks=[x-0.5 for x in range(W)])
plt.yticks(ticks=[y-0.5 for y in range(H)])
plt.grid(c="black", lw=2)
f = plt.gca()
f.axes.xaxis.set_ticklabels([])
f.axes.yaxis.set_ticklabels([])
plt.text(1.7, 2.05, "START", weight="bold")
[Figure: the reward grid rendered as a heat map, with the start cell labelled "START".]
First, supposing we've calculated the expected value from each state, we can construct the optimal policy. Let's also implement value iteration to actually calculate the expected values.
def get_transition_prob(source, target, action, transition_probs):
# Not exactly happy with how I've modelled the transition
# probabilities, it's too ad-hoc. And assumes that there's an
# entry for every state-action pair.
return transition_probs[(source, action)].get(target, 0)
def make_optimal_policy(value_map, transition_probs):
policy_map = make_optimal_policy_map(value_map, transition_probs)
def policy(s):
return policy_map[s]
return policy
def make_optimal_policy_map(value_map, transition_probs):
actions = {}
for s in STATES:
def action_expected_reward(a):
return sum(get_transition_prob(s, new_s, a, transition_probs) * value_map[new_s]
for new_s in STATES)
actions[s] = max(ACTIONS, key=action_expected_reward)
return actions
def value_iteration(reward_map, transition_probs, dropoff,
initial_value_map=None, eps=.0001):
value_map = (initial_value_map.copy()
if initial_value_map
else dict((s, 0) for s in STATES))
while True:
old_values = value_map.copy()
diff = 0
# SYNCHRONOUS update. We update all the expected values in lockstep, rather
# than updating them one at a time and using the new values as they are
# available (which would be asynchronous).
for s in STATES:
value_map[s] = (reward_map[s]
+ max(dropoff*sum(get_transition_prob(s, t, a, transition_probs)*old_values[t]
for t in STATES)
for a in ACTIONS))
diff += abs(value_map[s] - old_values[s])
if diff < eps:
break
return value_map
Initialise our expected values and estimates of rewards/transition probs.
value_map = {}
reward_map = {}
for s in STATES:
value_map[s] = 0
reward_map[s] = 0
transition_probs = {}
for pair in itertools.product(STATES, ACTIONS):
# If there were more states we wouldn't be computing
# all the transition probabilities like this...
distr = {}
transition_probs[pair] = distr
for target in STATES:
distr[target] = 1/len(STATES)
Now the main loop. In each iteration, run some simulations, then use the refined estimates of rewards / transition probabilities to update the expected values (via value iteration).
dropoff = .99
reward_est, state_count, transition_est = None, None, None
policy = make_optimal_policy(value_map, transition_probs)
rounds = 100
sims_per_round = 100
for i in range(rounds):
R_sum = 0
for _ in range(sims_per_round):
(R, sim_reward_est, sim_state_count,
sim_transition_est, sim_end_states) = simulate(
policy, dropoff=dropoff)
R_sum += R
# Incorporate the statistics from the simulation.
if reward_est is None:
reward_est = sim_reward_est
state_count = sim_state_count
transition_est = sim_transition_est
else:
for s, r in sim_reward_est.items(): reward_est[s] += r
for s, c in sim_state_count.items(): state_count[s] += c
for pair, sim_distr in sim_transition_est.items():
if pair not in transition_est:
transition_est[pair] = sim_distr
else:
distr = transition_est[pair]
for s, c in sim_distr.items():
distr[s] += c
# Update estimates.
for s in STATES:
c = state_count[s]
# Optimistic default reward to encourage visiting unknown states.
reward_map[s] = 1 if c == 0 else reward_est[s]/c
# Never goes anywhere, absorbing state.
for pair in itertools.product(sim_end_states, ACTIONS):
distr = transition_probs[pair]
for t in distr:
distr[t] = 0
for pair, counts in transition_est.items():
total_count = sum(counts.values())
distr = transition_probs[pair]
s, a = pair
if total_count > 0:
# Relying on the fact that we initialized this map so that
# all the states have an entry.
for t in distr:
# Set each target's probability to its empirical frequency for this
# (state, action) pair.
distr[t] = counts.get(t, 0)/total_count
# Update values.
value_map = value_iteration(reward_map, transition_probs, dropoff, initial_value_map=value_map)
# Update policy.
policy = make_optimal_policy(value_map, transition_probs)
# Print out the average score in each round, should indicate whether the
# policy is improving as we gather more data.
print(f"### Round {i}: {R_sum/sims_per_round}")
print(" Reward map --", reward_map)
print(" Value estimates --", value_map)
print(" Policy --", make_optimal_policy_map(value_map, transition_probs))
print()
Visualising the policy we've derived.
board = np.array([
[REWARDS[(i, j)] for j in range(W)]
for i in range(H)])
plt.matshow(board, cmap='spring')
plt.colorbar()
plt.xticks(ticks=[x-0.5 for x in range(W)])
plt.yticks(ticks=[y-0.5 for y in range(H)])
plt.grid(c="black", lw=2)
f = plt.gca()
f.axes.xaxis.set_ticklabels([])
f.axes.yaxis.set_ticklabels([])
for s in STATES:
if s in END_STATES:
continue
i, j = s
dx, dy = 0, 0
a = policy(s)
if a == Action.DOWN: dy = .25
elif a == Action.UP: dy = -.25
elif a == Action.RIGHT: dx = .25
else: dx = -.25
plt.arrow(j-dx/2, i-dy/2, dx, dy, length_includes_head=True,
head_width=0.1, fc="black")
As stated previously, for a given policy we can determine its exact expected values by solving a system of linear equations. The $i$-th state gives the following equation:

$$V(s_i) = R_i + \phi \sum_{j} P_{s_i, \pi(s_i)}(s_j) V(s_j),$$

which rearranges to

$$\sum_{j} V(s_j) \left(\phi P_{s_i, \pi(s_i)}(s_j) - [i=j]\right) = -R_i,$$

where $\phi$ is the dropoff (discount factor) and $[i=j]$ is 1 if $i=j$ and 0 otherwise. For this, let's use the true transition probabilities and rewards rather than the estimates, though it might be interesting to see the difference between the estimated expected values and the true expected values.
def eval_policy(policy, dropoff):
state_to_index = {}
i = 0
for s in STATES:
# The absorbing states have constant value, the system will end
# up unsolvable if we include them.
if s not in END_STATES:
state_to_index[s] = i
i += 1
N = len(state_to_index)
M = np.zeros((N,N))
R = np.zeros(N)
for s, i in state_to_index.items():
R[i] = -REWARDS[s]
for t, p in zip(*get_possible_next_states(s, policy(s), cumulative=False)):
if t in END_STATES:
# Expected value of an absorbing/end state is a constant term.
R[i] -= dropoff*p*REWARDS[t]
else:
M[i, state_to_index[t]] = dropoff*p - (1 if s==t else 0)
print(state_to_index)
print(R)
print(M)
soln = np.linalg.solve(M, R)
print(soln)
print(np.dot(M, soln))
V = dict([
(s, REWARDS[s] if s in END_STATES else soln[state_to_index[s]])
for s in STATES
])
return V
Hmm, this is giving nonsensical values (e.g. expected value of >1 for many states). And I'm not that bothered to debug it.
with np.printoptions(edgeitems=30, linewidth=100000,
formatter=dict(float=lambda x: "%.5g" % x)):
V = eval_policy(policy, dropoff)
print("Expected values by state!? --")
print(" ", V)
{(0, 0): 0, (0, 1): 1, (0, 2): 2, (1, 0): 3, (1, 1): 4, (1, 2): 5, (2, 0): 6, (2, 1): 7, (2, 2): 8, (2, 3): 9}
[0.1 0.1 -0.6425 0.1 0.1 0.1 0.1 0.1 0.1 0.22375]
[[-0.87625 0.7425 0 0.12375 0 0 0 0 0 0]
 [0 -0.87625 0.7425 0 0.12375 0 0 0 0 0]
 [0 0 -0.87625 0 0 0.12375 0 0 0 0]
 [0.7425 0 0 -0.87625 0.12375 0 0 0 0 0]
 [0 0.12375 0 0.7425 0 0 0 0.12375 0 0]
 [0 0 0.12375 0 0.7425 0 0 0 0.12375 0]
 [0 0 0 0.7425 0 0 -0.87625 0.12375 0 0]
 [0 0 0 0 0.7425 0 0.12375 0 0.12375 0]
 [0 0 0 0 0 0.7425 0 0.12375 0 0.12375]
 [0 0 0 0 0 0 0 0 0.7425 -0.87625]]
[-0.78325 -0.57226 0.080818 -1.3044 -3.7289 -4.6197 0.080818 9.207 23.1 19.319]
[0.1 0.1 -0.6425 0.1 0.1 0.1 0.1 0.1 0.1 0.22375]
Expected values by state!? --
   {(0, 0): -0.7832523792455475, (0, 1): -0.5722570663704574, (0, 2): 0.08081804503662932, (0, 3): 1, (1, 0): -1.3044365699704759, (1, 1): -3.7288699219944283, (1, 2): -4.619662125548715, (1, 3): -1, (2, 0): 0.08081804503662698, (2, 1): 9.20695729427412, (2, 2): 23.100482295010746, (2, 3): 19.319096267098978}
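A likely culprit, though I haven't verified it against the numbers above: the $-[i=j]$ term only gets applied when a state happens to appear among its own possible next states, and duplicate next states (which can occur because of clamping at the walls) overwrite each other in M rather than accumulating. A minimal sketch of a corrected version under that assumption, reusing the helpers defined above (the name eval_policy_fixed is just for illustration):
def eval_policy_fixed(policy, dropoff):
    # Index only the non-end states; end states have the constant value REWARDS[s].
    state_to_index = {s: i for i, s in enumerate(
        t for t in STATES if t not in END_STATES)}
    N = len(state_to_index)
    M = -np.eye(N)  # the -[i=j] term, applied to every row
    R = np.zeros(N)
    for s, i in state_to_index.items():
        R[i] = -REWARDS[s]
        for t, p in zip(*get_possible_next_states(s, policy(s), cumulative=False)):
            if t in END_STATES:
                # End-state values are constants, so they move to the right-hand side.
                R[i] -= dropoff*p*REWARDS[t]
            else:
                # += rather than =, since wall clamping can repeat a target state.
                M[i, state_to_index[t]] += dropoff*p
    soln = np.linalg.solve(M, R)
    return dict((s, REWARDS[s] if s in END_STATES else soln[state_to_index[s]])
                for s in STATES)
Whatever the cause, no state's value should exceed 1 here: the best possible outcome is the +1 goal reward, reached after at least one -0.1 step and discounted by $\phi$.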