Saturday, June 11, 2005

Simple function optimization

This is an optimization technique I learned about a few days ago at a symposium. "Optimization" in this case means finding the input parameter values that result in the best function output, or at least approximating them.

The performance of this implementation isn't as good as the presenters', in terms of the number of iterations it takes to approach the optimum. I'm not sure what the cause of that is, but it may well just be a case of their having optimized their optimizer (two different meanings for the word "optimize" in the same sentence... there oughta be a law.)

In any case, this is a pretty nice algorithm, which converges relatively fast and (perhaps more importantly) is quite simple and understandable. Once you understand it, you'll be able to code it out from memory.

The general idea is that the optimizer maintains a "swarm" of "particles" which move about the space of the function, trying to find the best locations. The rules used to control their movements are (loosely) inspired by those that guide the flocking behavior of birds and fish.
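The rule each particle follows can be sketched for a single particle in plain Python. It mirrors the array-based update in the code below, written out one coordinate at a time. `step` and its parameter names are hypothetical, chosen for illustration.

```python
import random


def step(position, velocity, personal_best, global_best, friction, rng=random):
    """Advance one particle by one time step (illustrative sketch).

    The velocity keeps a `friction` fraction of its previous value, then is
    pulled a random distance toward the swarm's best position and toward the
    particle's own best position, independently for each coordinate.
    """
    new_velocity = [
        friction * v
        + rng.random() * (g - x)  # pull toward the swarm's best
        + rng.random() * (p - x)  # pull toward this particle's best
        for x, v, p, g in zip(position, velocity, personal_best, global_best)
    ]
    new_position = [x + v for x, v in zip(position, new_velocity)]
    return new_position, new_velocity
```

If a particle is already sitting on both its own best and the swarm's best, both pulls vanish and friction alone slows it down.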

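The use of numarray perhaps makes this implementation harder to read, but the advantages (updating every particle and every parameter at once with a single array expression, instead of nested loops) are too great to ignore.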
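For example, scattering the initial particles across the ranges takes one broadcast expression instead of a nested loop. A minimal sketch, using NumPy (the successor to numarray) rather than numarray itself; the two function names are hypothetical:

```python
import numpy as np


def scatter_loop(unit, ranges):
    """Scale uniform [0, 1) samples into the ranges with explicit loops."""
    out = []
    for row in unit:  # one row per particle
        out.append([lo + u * (hi - lo) for u, (lo, hi) in zip(row, ranges)])
    return np.array(out)


def scatter_vectorized(unit, ranges):
    """The same scaling as one array expression.

    `lower` and `upper` hold one entry per parameter, so they broadcast
    across every particle row of `unit` at once.
    """
    lower = np.array([r[0] for r in ranges])
    upper = np.array([r[1] for r in ranges])
    return lower + unit * (upper - lower)
```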

from numarray import array, greater, maximum, minimum
from numarray.random_array.RandomArray2 import random as randarray

def pso(function, ranges, particle_count = 5, friction = 0.95, confine = True):
    """Performs a Particle Swarm Optimization on the passed function,
    performing one optimization step each time the generator is
    advanced. The yielded values are 2-tuples of (best known value,
    tuple of parameters resulting in the best known value).


    * function: The function you wish to find the optimal inputs
      for. It should be a function of X floating point parameters,
      where X = len(ranges). Further, the result should be a floating
      point number which can be interpreted as the function quality
      with a given set of parameters; larger results are better. If
      your desired function does not result in a float or is not
      parameterized by floats, feel free to wrap it in a function
      that does appropriate translation.

    * ranges: The bounds within which the exploratory particles will
      be created. If confine = True, these ranges also bound the space
      to be explored. Each range should be a 2-tuple of (lower bound,
      upper bound), and there should be one for each parameter of the
      function.

    * particle_count: The number of particles to maintain while
      optimizing. Fairly small numbers are often good. Try 5 if you're
      not sure.

    * friction: The fraction of each particle's 'velocity' that is
      retained on each time step. Friction will be clamped to the
      range [0, 1], and then multiplied by the current particle
      velocities. A value of 1 means no velocity is ever lost, which
      will tend to cause the optimizer to fail. If you're not sure
      what value to use, try 0.95.

    * confine: May be set to False, allowing exploration beyond the
      bounds set by the ranges parameter.
    """
    # Clamp friction to [0, 1], as promised in the docstring.
    friction = min(max(friction, 0.0), 1.0)

    dimensions = len(ranges)
    upper = array([r[1] for r in ranges])
    lower = array([r[0] for r in ranges])
    shape = (particle_count, dimensions)

    # Scatter the particles uniformly within the ranges. The bound arrays
    # have one entry per parameter, so they broadcast across the rows.
    positions = lower + randarray(shape) * (upper - lower)

    # Random initial velocities in both directions, scaled to the width
    # of each range.
    velocities = (randarray(shape) - 0.5) * (upper - lower)

    # Each particle's best value so far, and where it was found.
    lbest = array([function(*pos) for pos in positions])
    lbestat = array(positions)

    # The best value found by the whole swarm, and where it was found.
    best = lbest.argmax()
    gbest = lbest[best]
    gbestat = array(positions[best])

    while True:
        # Lose some velocity to friction, then pull each particle a random
        # amount toward the swarm's best position and toward its own best.
        velocities *= friction
        velocities += randarray(shape) * (gbestat - positions)
        velocities += randarray(shape) * (lbestat - positions)

        positions += velocities

        if confine:
            positions = minimum(maximum(positions, lower), upper)

        values = array([function(*pos) for pos in positions])

        # Record new personal bests for the particles that improved.
        replace = greater(values, lbest)
        lbest[replace] = values[replace]
        lbestat[replace] = positions[replace]

        # Record a new swarm best, copying the row so that later moves
        # can't change it.
        best = values.argmax()
        if values[best] > gbest:
            gbest = values[best]
            gbestat = array(positions[best])

        yield gbest, tuple(gbestat)
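numarray has long since been superseded by NumPy and is no longer maintained, so the same loop can be sketched as a NumPy port. This assumes NumPy's `default_rng` generator API (NumPy 1.17 or later); the name `pso_numpy`, the `seed` parameter and the example `quality` function are additions for illustration, and `np.clip` stands in for the `minimum`/`maximum` pair.

```python
import itertools

import numpy as np


def pso_numpy(function, ranges, particle_count=5, friction=0.95,
              confine=True, seed=None):
    """Particle Swarm Optimization on NumPy arrays (sketch).

    Yields (best known value, tuple of parameters) once per step.
    """
    rng = np.random.default_rng(seed)
    friction = min(max(friction, 0.0), 1.0)
    lower = np.array([r[0] for r in ranges], dtype=float)
    upper = np.array([r[1] for r in ranges], dtype=float)
    shape = (particle_count, len(ranges))

    # Scatter particles uniformly within the ranges, with random
    # initial velocities scaled to each range's width.
    positions = lower + rng.random(shape) * (upper - lower)
    velocities = (rng.random(shape) - 0.5) * (upper - lower)

    # Per-particle bests and the swarm-wide best (copied, not viewed).
    lbest = np.array([function(*p) for p in positions], dtype=float)
    lbestat = positions.copy()
    best = lbest.argmax()
    gbest = lbest[best]
    gbestat = positions[best].copy()

    while True:
        velocities *= friction
        velocities += rng.random(shape) * (gbestat - positions)
        velocities += rng.random(shape) * (lbestat - positions)
        positions += velocities
        if confine:
            positions = np.clip(positions, lower, upper)

        values = np.array([function(*p) for p in positions], dtype=float)
        improved = values > lbest
        lbest[improved] = values[improved]
        lbestat[improved] = positions[improved]

        best = values.argmax()
        if values[best] > gbest:
            gbest = values[best]
            gbestat = positions[best].copy()

        yield float(gbest), tuple(float(x) for x in gbestat)


def quality(x, y):
    """Example objective: peaks at 0 when (x, y) == (1, -2)."""
    return -(x - 1.0) ** 2 - (y + 2.0) ** 2


if __name__ == "__main__":
    steps = pso_numpy(quality, [(-5.0, 5.0), (-5.0, 5.0)], seed=0)
    # Run 500 steps; each yield is the best result found so far.
    for value, params in itertools.islice(steps, 500):
        pass
    print(value, params)
```

Because the generator never stops on its own, the caller decides when to quit, here with `itertools.islice`; checking whether the best value has stopped improving would work just as well.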

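The docstring suggests wrapping a function whose inputs or output don't fit what the optimizer expects. Since the optimizer keeps whichever value is largest, a cost that should be minimized also needs its sign flipped. A minimal sketch, with a hypothetical cost function over integer parameters (`batch_cost` and `as_quality` are illustrative names, not part of the original):

```python
def batch_cost(workers, batch_size):
    """Hypothetical cost to *minimize*, defined only on integer inputs."""
    return abs(workers * batch_size - 96) + workers


def as_quality(cost):
    """Wrap an integer-parameter cost so the optimizer can maximize it.

    The float parameters proposed by the swarm are rounded to the nearest
    integers, and the cost is negated so that lower cost means higher quality.
    """
    def quality(*params):
        return -float(cost(*(int(round(p)) for p in params)))
    return quality


batch_quality = as_quality(batch_cost)
# batch_quality takes floats and returns a float, so it can be handed to
# pso() together with float ranges such as [(1.0, 16.0), (1.0, 64.0)].
```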

Daniel Arbuckle said...

Whoops. Somewhere before the while loop, the following line should appear:
friction = min(max(friction, 0.0), 1.0)

8:24 AM  
