Cross Entropy Method (CEM)¶
Paper: The cross-entropy method: A unified approach to Monte Carlo simulation, randomized optimization and machine learning [1]
Cross Entropy Method (CEM) works by iteratively optimizing a Gaussian distribution over policy parameters.
In each epoch, CEM does the following:
1. Sample n_samples policy parameter vectors from a Gaussian distribution with mean cur_mean and standard deviation cur_std.
2. Collect episodes for each sampled policy.
3. Update cur_mean and cur_std by maximum likelihood estimation over the n_best policies with the highest return.
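The loop above can be sketched in plain NumPy. This is a minimal illustration, not garage's implementation: it optimizes a toy quadratic objective in place of episode returns, and the names (cem_optimize, init_std) are invented for this sketch.

```python
import numpy as np

def cem_optimize(objective, dim, n_samples=50, best_frac=0.2,
                 n_epochs=100, init_std=1.0, seed=0):
    """Minimal CEM sketch: iteratively refit a diagonal Gaussian
    to the elite fraction of sampled solutions."""
    rng = np.random.default_rng(seed)
    cur_mean = np.zeros(dim)
    cur_std = np.full(dim, init_std)
    n_best = max(1, int(n_samples * best_frac))
    for _ in range(n_epochs):
        # 1. Sample candidate parameter vectors from the Gaussian.
        samples = rng.normal(cur_mean, cur_std, size=(n_samples, dim))
        # 2. Evaluate each candidate (stands in for episode returns).
        returns = np.array([objective(s) for s in samples])
        # 3. Refit mean/std to the n_best elites (Gaussian MLE).
        elites = samples[np.argsort(returns)[-n_best:]]
        cur_mean = elites.mean(axis=0)
        cur_std = elites.std(axis=0) + 1e-6  # keep a little exploration
    return cur_mean

# Toy objective: maximize the negative squared distance to [2, -1].
target = np.array([2.0, -1.0])
best = cem_optimize(lambda x: -np.sum((x - target) ** 2), dim=2)
```

Because only the elite samples are used to refit the distribution, the Gaussian contracts toward high-return regions; the small constant added to cur_std prevents it from collapsing to zero before convergence.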
Examples¶
NumPy¶
#!/usr/bin/env python3
"""This is an example to train a task with Cross Entropy Method.

Here it runs the CartPole-v1 environment for 100 epochs.

Results:
    AverageReturn: 100
    RiseTime: epoch 8

"""
from garage import wrap_experiment
from garage.envs import GymEnv
from garage.experiment.deterministic import set_seed
from garage.np.algos import CEM
from garage.tf.policies import CategoricalMLPPolicy
from garage.trainer import TFTrainer
@wrap_experiment
def cem_cartpole(ctxt=None, seed=1):
    """Train CEM with Cartpole-v1 environment.

    Args:
        ctxt (garage.experiment.ExperimentContext): The experiment
            configuration used by Trainer to create the snapshotter.
        seed (int): Used to seed the random number generator to produce
            determinism.

    """
    set_seed(seed)
    with TFTrainer(snapshot_config=ctxt) as trainer:
        env = GymEnv('CartPole-v1')
        policy = CategoricalMLPPolicy(name='policy',
                                      env_spec=env.spec,
                                      hidden_sizes=(32, 32))
        n_samples = 20
        algo = CEM(env_spec=env.spec,
                   policy=policy,
                   best_frac=0.05,
                   n_samples=n_samples)
        trainer.setup(algo, env)
        trainer.train(n_epochs=100, batch_size=1000)


cem_cartpole(seed=1)
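As a quick sanity check on the hyperparameters in the example, best_frac sets the elite fraction: with 20 sampled policies and best_frac=0.05, only a single top policy is kept per epoch. The rounding below is an illustrative assumption, not necessarily garage's exact internal computation.

```python
# Hypothetical elite-count arithmetic (assumed rounding, not garage internals).
n_samples = 20
best_frac = 0.05
n_best = max(1, int(n_samples * best_frac))
print(n_best)  # -> 1
```

With such a small elite set, the fitted Gaussian can contract quickly; raising best_frac or n_samples trades faster convergence for more robust exploration.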