mirror of
https://github.com/VCMP-SqMod/SqMod.git
synced 2025-01-19 03:57:14 +01:00
336 lines
12 KiB
C++
336 lines
12 KiB
C++
/**
|
|
* \file RandomSelect.hpp
|
|
* \brief Header for RandomSelect.
|
|
*
|
|
* An implementation of the Walker algorithm for selecting from a finite set.
|
|
*
|
|
* Copyright (c) Charles Karney (2006-2011) <charles@karney.com> and licensed
|
|
* under the MIT/X11 License. For more information, see
|
|
* http://randomlib.sourceforge.net/
|
|
**********************************************************************/
|
|
|
|
#if !defined(RANDOMLIB_RANDOMSELECT_HPP)
|
|
#define RANDOMLIB_RANDOMSELECT_HPP 1
|
|
|
|
#include <vector>
|
|
#include <limits>
|
|
#include <stdexcept>
|
|
|
|
#if defined(_MSC_VER)
|
|
// Squelch warnings about constant conditional expressions
|
|
# pragma warning (push)
|
|
# pragma warning (disable: 4127)
|
|
#endif
|
|
|
|
namespace RandomLib {
|
|
/**
|
|
* \brief Random selection from a discrete set.
|
|
*
|
|
* An implementation of Walker algorithm for selecting from a finite set
|
|
* (following Knuth, TAOCP, Vol 2, Sec 3.4.1.A). This provides a rapid way
|
|
* of selecting one of several choices depending on a discrete set weights.
|
|
* Original citation is\n A. J. Walker,\n An Efficient Method for Generating
|
|
* Discrete Random Variables and General Distributions,\n ACM TOMS 3,
|
|
* 253--256 (1977).
|
|
*
|
|
* There are two changes here in the setup algorithm as given by Knuth:
|
|
*
|
|
* - The probabilities aren't sorted at the beginning of the setup; nor are
|
|
* they maintained in a sorted order. Instead they are just partitioned on
|
|
* the mean. This improves the setup time from O(\e k<sup>2</sup>) to O(\e
|
|
* k).
|
|
*
|
|
* - The internal calculations are carried out with type \e NumericType. If
|
|
* the input weights are of integer type, then choosing an integer type for
|
|
* \e NumericType yields an exact solution for the returned distribution
|
|
* (assuming that the underlying random generator is exact.)
|
|
*
|
|
* Example:
|
|
* \code
|
|
#include <RandomLib/RandomSelect.hpp>
|
|
|
|
// Weights for throwing a pair of dice
|
|
unsigned w[] = { 0, 0, 1, 2, 3, 4, 5, 6, 5, 4, 3, 2, 1 };
|
|
|
|
// Initialize selection
|
|
RandomLib::RandomSelect<unsigned> sel(w, w + 13);
|
|
|
|
RandomLib::Random r; // Initialize random numbers
|
|
std::cout << "Seed set to " << r.SeedString() << "\n";
|
|
|
|
std::cout << "Throw a pair of dice 100 times:";
|
|
for (unsigned i = 0; i < 100; ++i)
|
|
std::cout << " " << sel(r);
|
|
std::cout << "\n";
|
|
\endcode
|
|
*
|
|
* @tparam NumericType the numeric type to use (default double).
|
|
**********************************************************************/
|
|
template<typename NumericType = double> class RandomSelect {
|
|
public:
|
|
/**
|
|
* Initialize in a cleared state (equivalent to having a single
|
|
* choice).
|
|
**********************************************************************/
|
|
RandomSelect() : _k(0), _wsum(0), _wmax(0) {}
|
|
|
|
/**
|
|
* Initialize with a weight vector \e w of elements of type \e WeightType.
|
|
* Internal calculations are carried out with type \e NumericType. \e
|
|
* NumericType needs to allow Choices() * MaxWeight() to be represented.
|
|
* Sensible combinations are:
|
|
* - \e WeightType integer, \e NumericType integer with
|
|
* digits(\e NumericType) ≥ digits(\e WeightType)
|
|
* - \e WeightType integer, \e NumericType real
|
|
* - \e WeightType real, \e NumericType real with digits(\e NumericType)
|
|
* ≥ digits(\e WeightType)
|
|
*
|
|
* @tparam WeightType the type of the weights.
|
|
* @param[in] w the vector of weights.
|
|
* @exception RandomErr if any of the weights are negative or if the total
|
|
* weight is not positive.
|
|
**********************************************************************/
|
|
template<typename WeightType>
|
|
RandomSelect(const std::vector<WeightType>& w) { Init(w.begin(), w.end()); }
|
|
|
|
/**
|
|
* Initialize with a weight given by a pair of iterators [\e a, \e b).
|
|
*
|
|
* @tparam InputIterator the type of the iterator.
|
|
* @param[in] a the beginning iterator.
|
|
* @param[in] b the ending iterator.
|
|
* @exception RandomErr if any of the weights are negative or if the total
|
|
* weight is not positive.
|
|
**********************************************************************/
|
|
template<typename InputIterator>
|
|
RandomSelect(InputIterator a, InputIterator b);
|
|
|
|
/**
|
|
* Clear the state (equivalent to having a single choice).
|
|
**********************************************************************/
|
|
void Init() throw()
|
|
{ _k = 0; _wsum = 0; _wmax = 0; _Q.clear(); _Y.clear(); }
|
|
|
|
/**
|
|
* Re-initialize with a weight vector \e w. Leave state unaltered in the
|
|
* case of an error.
|
|
*
|
|
* @tparam WeightType the type of the weights.
|
|
* @param[in] w the vector of weights.
|
|
**********************************************************************/
|
|
template<typename WeightType>
|
|
void Init(const std::vector<WeightType>& w) { Init(w.begin(), w.end()); }
|
|
|
|
/**
|
|
* Re-initialize with a weight given as a pair of iterators [\e a, \e b).
|
|
* Leave state unaltered in the case of an error.
|
|
*
|
|
* @tparam InputIterator the type of the iterator.
|
|
* @param[in] a the beginning iterator.
|
|
* @param[in] b the ending iterator.
|
|
**********************************************************************/
|
|
template<typename InputIterator>
|
|
void Init(InputIterator a, InputIterator b) {
|
|
RandomSelect<NumericType> t(a, b);
|
|
_Q.reserve(t._k);
|
|
_Y.reserve(t._k);
|
|
*this = t;
|
|
}
|
|
|
|
/**
|
|
* Return an index into the weight vector with probability proportional to
|
|
* the weight.
|
|
*
|
|
* @tparam Random the type of RandomCanonical generator.
|
|
* @param[in,out] r the RandomCanonical generator.
|
|
* @return the random index into the weight vector.
|
|
**********************************************************************/
|
|
template<class Random>
|
|
unsigned operator()(Random& r) const throw() {
|
|
if (_k <= 1)
|
|
return 0; // Special cases
|
|
const unsigned K = r.template Integer<unsigned>(_k);
|
|
// redundant casts to type NumericType to prevent warning from MS Project
|
|
return (std::numeric_limits<NumericType>::is_integer ?
|
|
r.template Prob<NumericType>(NumericType(_Q[K]),
|
|
NumericType(_wsum)) :
|
|
r.template Prob<NumericType>(NumericType(_Q[K]))) ?
|
|
K : _Y[K];
|
|
}
|
|
|
|
/**
|
|
* @return the sum of the weights.
|
|
**********************************************************************/
|
|
NumericType TotalWeight() const throw() { return _wsum; }
|
|
|
|
/**
|
|
* @return the maximum weight.
|
|
**********************************************************************/
|
|
NumericType MaxWeight() const throw() { return _wmax; }
|
|
|
|
/**
|
|
* @param[in] i the index in to the weight vector.
|
|
* @return the weight for sample \e i. Weight(i) / TotalWeight() gives the
|
|
* probability of sample \e i.
|
|
**********************************************************************/
|
|
NumericType Weight(unsigned i) const throw() {
|
|
if (i >= _k)
|
|
return NumericType(0);
|
|
else if (_k == 1)
|
|
return _wsum;
|
|
const NumericType n = std::numeric_limits<NumericType>::is_integer ?
|
|
_wsum : NumericType(1);
|
|
NumericType p = _Q[i];
|
|
for (unsigned j = _k; j;)
|
|
if (_Y[--j] == i)
|
|
p += n - _Q[j];
|
|
// If NumericType is integral, then p % _k == 0.
|
|
// assert(!std::numeric_limits<NumericType>::is_integer || p % _k == 0);
|
|
return (p / NumericType(_k)) * (_wsum / n);
|
|
}
|
|
|
|
/**
|
|
* @return the number of choices, i.e., the length of the weight vector.
|
|
**********************************************************************/
|
|
unsigned Choices() const throw() { return _k; }
|
|
|
|
private:
|
|
|
|
/**
|
|
* Size of weight vector
|
|
**********************************************************************/
|
|
unsigned _k;
|
|
/**
|
|
* Vector of cutoffs
|
|
**********************************************************************/
|
|
std::vector<NumericType> _Q;
|
|
/**
|
|
* Vector of aliases
|
|
**********************************************************************/
|
|
std::vector<unsigned> _Y;
|
|
/**
|
|
* The sum of the weights
|
|
**********************************************************************/
|
|
NumericType _wsum;
|
|
/**
|
|
* The maximum weight
|
|
**********************************************************************/
|
|
NumericType _wmax;
|
|
|
|
};
|
|
|
|
template<typename NumericType> template<typename InputIterator>
|
|
RandomSelect<NumericType>::RandomSelect(InputIterator a, InputIterator b) {
|
|
|
|
typedef typename std::iterator_traits<InputIterator>::value_type
|
|
WeightType;
|
|
// Disallow WeightType = real, NumericType = integer
|
|
STATIC_ASSERT(std::numeric_limits<WeightType>::is_integer ||
|
|
!std::numeric_limits<NumericType>::is_integer,
|
|
"RandomSelect: inconsistent WeightType and NumericType");
|
|
|
|
// If WeightType and NumericType are the same type, NumericType as precise
|
|
// as WeightType
|
|
STATIC_ASSERT(std::numeric_limits<WeightType>::is_integer !=
|
|
std::numeric_limits<NumericType>::is_integer ||
|
|
std::numeric_limits<NumericType>::digits >=
|
|
std::numeric_limits<WeightType>::digits,
|
|
"RandomSelect: NumericType insufficiently precise");
|
|
|
|
_wsum = 0;
|
|
_wmax = 0;
|
|
std::vector<NumericType> p;
|
|
|
|
for (InputIterator wptr = a; wptr != b; ++wptr) {
|
|
// Test *wptr < 0 without triggering compiler warning when *wptr =
|
|
// unsigned
|
|
if (!(*wptr > 0 || *wptr == 0))
|
|
// This also catches NaNs
|
|
throw RandomErr("RandomSelect: Illegal weight");
|
|
NumericType w = NumericType(*wptr);
|
|
if (w > (std::numeric_limits<NumericType>::max)() - _wsum)
|
|
throw RandomErr("RandomSelect: Overflow");
|
|
_wsum += w;
|
|
_wmax = w > _wmax ? w : _wmax;
|
|
p.push_back(w);
|
|
}
|
|
|
|
_k = unsigned(p.size());
|
|
if (_wsum <= 0)
|
|
throw RandomErr("RandomSelect: Zero total weight");
|
|
|
|
if (_k <= 1) {
|
|
// We treat k <= 1 as a special case in operator()
|
|
_Q.clear();
|
|
_Y.clear();
|
|
return;
|
|
}
|
|
|
|
if ((std::numeric_limits<NumericType>::max)()/NumericType(_k) <
|
|
NumericType(_wmax))
|
|
throw RandomErr("RandomSelect: Overflow");
|
|
|
|
std::vector<unsigned> j(_k);
|
|
_Q.resize(_k);
|
|
_Y.resize(_k);
|
|
|
|
// Pointers to the next empty low and high slots
|
|
unsigned u = 0;
|
|
unsigned v = _k - 1;
|
|
|
|
// Scale input and store in p and setup index array j. Note _wsum =
|
|
// mean(p). We could scale out _wsum here, but the following is exact when
|
|
// w[i] are low integers.
|
|
for (unsigned i = 0; i < _k; ++i) {
|
|
p[i] *= NumericType(_k);
|
|
j[p[i] > _wsum ? v-- : u++] = i;
|
|
}
|
|
|
|
// Pointers to the next low and high slots to use. Work towards the
|
|
// middle. This simplifies the loop exit test to u == v.
|
|
u = 0;
|
|
v = _k - 1;
|
|
|
|
// For integer NumericType, store the unnormalized probability in _Q and
|
|
// select using the exact Prob(_Q[k], _wsum). For real NumericType, store
|
|
// the normalized probability and select using Prob(_Q[k]). There will be
|
|
// a round off error in performing the division; but there is also the
|
|
// potential for round off errors in performing the arithmetic on p. There
|
|
// is therefore no point in simulating the division exactly using the
|
|
// slower Prob(real, real).
|
|
const NumericType n = std::numeric_limits<NumericType>::is_integer ?
|
|
NumericType(1) : _wsum;
|
|
|
|
while (true) {
|
|
// A loop invariant here is mean(p[j[u..v]]) == _wsum
|
|
_Q[j[u]] = p[j[u]] / n;
|
|
|
|
// If all arithmetic were exact this assignment could be:
|
|
// if (p[j[u]] < _wsum) _Y[j[u]] = j[v];
|
|
// But the following is safer:
|
|
_Y[j[u]] = j[p[j[u]] < _wsum ? v : u];
|
|
|
|
if (u == v) {
|
|
// The following assertion may fail because of roundoff errors
|
|
// assert( p[j[u]] == _wsum );
|
|
break;
|
|
}
|
|
|
|
// Update p, u, and v maintaining the loop invariant
|
|
p[j[v]] = p[j[v]] - (_wsum - p[j[u]]);
|
|
if (p[j[v]] > _wsum)
|
|
++u;
|
|
else
|
|
j[u] = j[v--];
|
|
}
|
|
return;
|
|
}
|
|
|
|
} // namespace RandomLib
|
|
|
|
#if defined(_MSC_VER)
|
|
# pragma warning (pop)
|
|
#endif
|
|
|
|
#endif // RANDOMLIB_RANDOMSELECT_HPP
|