1
0
mirror of https://github.com/VCMP-SqMod/SqMod.git synced 2025-01-19 03:57:14 +01:00
SqMod/include/RandomLib/UniformInteger.hpp

254 lines
11 KiB
C++
Raw Normal View History

/**
* \file UniformInteger.hpp
* \brief Header for UniformInteger
*
* Partially sample a uniform integer distribution.
*
* Copyright (c) Charles Karney (2013) <charles@karney.com> and licensed
* under the MIT/X11 License. For more information, see
* http://randomlib.sourceforge.net/
**********************************************************************/
#if !defined(RANDOMLIB_UNIFORMINTEGER_HPP)
#define RANDOMLIB_UNIFORMINTEGER_HPP 1
#include <limits>
#include <ostream>
namespace RandomLib {
/**
* \brief The partial uniform integer distribution.
*
* A class to sample in [0, \e m). For background, see:
* - D. E. Knuth and A. C. Yao, The Complexity of Nonuniform Random Number
* Generation, in "Algorithms and Complexity" (Academic Press, 1976),
* pp. 357--428.
* - J. Lumbroso, Optimal Discrete Uniform Generation from Coin Flips,
* and Applications, http://arxiv.org/abs/1304.1916 (2013)
* .
* Lumbroso's algorithm is a realization of the Knuth-Yao method for the case
* of uniform probabilities. This class generalizes the method to accept
* random digits in a base, \e b = 2<sup>\e bits</sup>. An important
* additional feature is that only sufficient random digits are drawn to
* narrow the allowed range to a power of \e b. Thus after
* <code>UniformInteger<int,1> u(r,5)</code>, \e u represents \verbatim
range prob
[0,4) 8/15
[0,2) 2/15
[2,4) 2/15
4 1/5 \endverbatim
* <code>u.Min()</code> and <code>u.Max()</code> give the extent of the
* closed range. The number of additional random digits needed to fix the
* value is given by <code>u.Entropy()</code>. The comparison operations may
* require additional digits to be drawn and so the range might be narrowed
* down. If you need a definite value then use <code>u(r)</code>.
*
* The DiscreteNormalAlt class uses UniformInteger to achieve an
* asymptotically ideal scaling wherein the number of random bits required
* per sample is constant + log<sub>2</sub>&sigma;. If Lumbroso's algorithm
* for sampling in [0,\e m) were used the log<sub>2</sub>&sigma; term would
* be multiplied by about 1.4.
*
* It is instructive to look at the Knuth-Yao discrete distribution
* generating (DDG) tree for the case \e m = 5 (the binary expansion of 1/5
* is 0.00110011...); Lumbroso's algorithm implements this tree.
* \image html ky-5.png "Knuth-Yao for \e m = 5"
*
* UniformInteger collapses all of the full subtrees above to their parent
* nodes to yield this tree where now some of the outcomes are ranges.
* \image html ky-5-collapse.png "Collapsed Knuth-Yao for \e m = 5"
*
* Averaging over many samples, the maximum number of digits required to
* construct a UniformInteger, i.e., invoking
* <code>UniformInteger(r,m)</code>, is (2\e b &minus; 1)/(\e b &minus; 1).
* (Note that this does not increase as \e m increases.) The maximum number
* of digits required to sample specific integers, i.e., invoking
* <code>UniformInteger(r,m)(r)</code>, is <i>b</i>/(\e b &minus; 1) +
* log<sub>\e b</sub>\e m. The worst cases are when \e m is slightly more
* than a power of \e b.
*
* The number of random bits required for sampling is shown as a function of
* the fractional part of log<sub>2</sub>\e m below. The red line shows what
* Lumbroso calls the "toll", the number of bits in excess of the entropy
* that are required for sampling.
* \image html
* uniform-bits.png "Random bits to sample in [0,\e m) for \e b = 2"
*
* @tparam IntType the type of the integer (must be signed).
* @tparam bits the number of bits in each digit used for sampling;
* the base for sampling is \e b = 2<sup>\e bits</sup>.
**********************************************************************/
template<typename IntType = int, int bits = 1> class UniformInteger {
public:
  /**
   * Constructor creating a partially sampled integer in [0, \e m)
   *
   * @param[in] r random object.
   * @param[in] m constructed object represents an integer in [0, \e m);
   *   values of \e m less than 1 are treated as 1.
   * @param[in] flip if true, rearrange the ranges so that the widest ones
   *   are near the upper end of [0, \e m) (default false).
   *
   * This samples enough random digits to obtain a uniform range whose size
   * is a power of the base. The range can subsequently be narrowed by
   * sampling additional digits.
   **********************************************************************/
  template<class Random>
  UniformInteger(Random& r, IntType m, bool flip = false);
  /**
   * @return the minimum of the current range.
   **********************************************************************/
  IntType Min() const { return _a; }
  /**
   * @return the maximum of the current range.
   **********************************************************************/
  IntType Max() const { return _a + (IntType(1) << (_l * bits)) - 1; }
  /**
   * @return the entropy of the current range (in units of random digits).
   *
   * Max() + 1 - Min() = 2<sup>Entropy() * \e bits</sup>.
   **********************************************************************/
  IntType Entropy() const { return _l; }
  /**
   * Sample until the entropy vanishes, i.e., Min() = Max().
   *
   * @tparam Random the type of the random generator.
   * @param[in,out] r a random generator.
   * @return the resulting integer sample.
   **********************************************************************/
  template<class Random> IntType operator()(Random& r)
  { while (_l) Refine(r); return _a; }
  /**
   * Negate the range, [Min(), Max()] &rarr; [&minus;Max(), &minus;Min()].
   **********************************************************************/
  void Negate() { _a = -Max(); }
  /**
   * Add a constant to the range
   *
   * @param[in] c the constant to be added.
   *
   * [Min(), Max()] &rarr; [Min() + \e c, Max() + \e c].
   **********************************************************************/
  void Add(IntType c) { _a += c; }
  /**
   * Compare with a fraction, *this &lt; <i>p</i>/<i>q</i>
   *
   * Random digits are drawn only until the whole current range lies on one
   * side of <i>p</i>/<i>q</i>, so the range may be narrowed as a result.
   *
   * @tparam Random the type of the random generator.
   * @param[in,out] r a random generator.
   * @param[in] p the numerator of the fraction.
   * @param[in] q the denominator of the fraction (require \e q &gt; 0).
   * @return true if *this &lt; <i>p</i>/<i>q</i>.
   **********************************************************************/
  template<class Random> bool LessThan(Random& r, IntType p, IntType q) {
    for (;;) {
      // With q > 0, x < p/q is equivalent to q * x < p; this avoids integer
      // division (use Check() to rule out overflow of the products).
      if (q * Max() < p) return true;      // whole range is below p/q
      if (!(q * Min() < p)) return false;  // whole range is at/above p/q
      Refine(r);                           // range straddles p/q: narrow it
    }
  }
  /**
   * Compare with a fraction, *this &le; <i>p</i>/<i>q</i>
   *
   * @tparam Random the type of the random generator.
   * @param[in,out] r a random generator.
   * @param[in] p the numerator of the fraction.
   * @param[in] q the denominator of the fraction (require \e q &gt; 0).
   * @return true if *this &le; <i>p</i>/<i>q</i>.
   **********************************************************************/
  template<class Random>
  bool LessThanEqual(Random& r, IntType p, IntType q)
  // For integer x and q > 0: x <= p/q  <=>  q*x <= p  <=>  q*x < p + 1.
  { return LessThan(r, p + 1, q); }
  /**
   * Compare with a fraction, *this &gt; <i>p</i>/<i>q</i>
   *
   * @tparam Random the type of the random generator.
   * @param[in,out] r a random generator.
   * @param[in] p the numerator of the fraction.
   * @param[in] q the denominator of the fraction (require \e q &gt; 0).
   * @return true if *this &gt; <i>p</i>/<i>q</i>.
   **********************************************************************/
  template<class Random>
  bool GreaterThan(Random& r, IntType p, IntType q)
  { return !LessThanEqual(r, p, q); }
  /**
   * Compare with a fraction, *this &ge; <i>p</i>/<i>q</i>
   *
   * @tparam Random the type of the random generator.
   * @param[in,out] r a random generator.
   * @param[in] p the numerator of the fraction.
   * @param[in] q the denominator of the fraction (require \e q &gt; 0).
   * @return true if *this &ge; <i>p</i>/<i>q</i>.
   **********************************************************************/
  template<class Random>
  bool GreaterThanEqual(Random& r, IntType p, IntType q)
  { return !LessThan(r, p, q); }
  /**
   * Check that overflow will not happen.
   *
   * @param[in] mmax the largest \e m in the constructor.
   * @param[in] qmax the largest \e q in LessThan() (require \e qmax &gt; 0;
   *   it is used as a divisor).
   * @return true if overflow will not happen.
   *
   * It is important that this check be carried out. If overflow occurs,
   * incorrect results are obtained and the constructor may never terminate.
   **********************************************************************/
  static bool Check(IntType mmax, IntType qmax) {
    const IntType maxint = (std::numeric_limits<IntType>::max)();
    // The constructor shifts values below mmax left by bits, and LessThan
    // multiplies values below mmax by q; both must stay representable.
    return ( mmax - 1 <= (maxint >> bits) &&
             mmax - 1 <= maxint / qmax );
  }
private:
  // The current range is _a + [0, 2^(bits*_l)); _l is the number of random
  // digits that still have to be drawn to pin the value down exactly.
  IntType _a, _l;
  // Draw one random base-2^bits digit from r via r.Integer<bits>().
  template<class Random> static unsigned RandomDigit(Random& r) throw()
  { return unsigned(r.template Integer<bits>()); }
  // Fix the most significant undetermined digit of the range by drawing one
  // random digit. Only gets called if _l > 0.
  template<class Random> void Refine(Random& r) {
    --_l;
    // Convert to IntType *before* shifting. The shift count bits * _l may
    // reach the width of unsigned when IntType is wider (e.g. long long),
    // and shifting an unsigned by >= its width is undefined behaviour.
    // The shifted digit always fits in IntType because the range size
    // 2^(bits*(_l+1)) was itself representable (see Max()).
    _a += IntType(RandomDigit(r)) << (bits * _l);
  }
};
// Out-of-class definition of the constructor. It plays out Lumbroso's
// algorithm one base-b digit at a time (b = 2^bits), but stops as soon as
// the value is known to lie in a range [_a, _a + b^_l) in which every value
// is an admissible outcome; the remaining _l digits are left undrawn.
template<typename IntType, int bits> template<class Random>
UniformInteger<IntType, bits>::UniformInteger(Random& r, IntType m, bool flip)
{
// Compile-time checks on the template parameters.
// NOTE(review): STATIC_ASSERT is not defined in this header; it must come
// from a header included before this one.
STATIC_ASSERT(std::numeric_limits<IntType>::is_integer,
"UniformInteger: invalid integer type IntType");
STATIC_ASSERT(std::numeric_limits<IntType>::is_signed,
"UniformInteger: IntType must be a signed type");
STATIC_ASSERT(bits > 0 && bits < std::numeric_limits<IntType>::digits &&
bits <= std::numeric_limits<unsigned>::digits,
"UniformInteger: bits out of range");
// Values of m below 1 are treated as 1 (the result is then always 0).
m = m < 1 ? 1 : m;
// Outer loop: (v, c) is the state of Lumbroso's algorithm. c accumulates
// the random digits drawn so far (c <<= bits; c += digit) and v is scaled
// in step with it; each pass draws exactly one real digit.
for (IntType v = 1, c = 0;;) {
// _l counts the undrawn digits considered by the inner simulation.
_l = 0; _a = c;
// d is the width b^_l of the range [a, a + d) of possible values of c.
for (IntType w = v, a = c, d = 1;;) {
// play out Lumbroso's algorithm without drawing random digits with w
// playing the role of v and c represented by the range [a, a + d).
// Return if both ends of range qualify as return values at the same
// time. Otherwise, fail and draw another random digit.
if (w >= m) {
// Remove the largest multiple of m not exceeding a from both a and w.
IntType j = (a / m) * m; a -= j; w -= j;
if (w >= m) {
// The whole range [a, a + d) lies inside [0, m): done. With flip,
// mirror the range within [0, m), i.e. map it to [m - a - d, m - a).
if (a + d <= m) { _a = !flip ? a : m - a - d; return; }
break;
}
}
// Undecided: account for one more undrawn digit (scale everything by b).
w <<= bits; a <<= bits; d <<= bits; ++_l;
}
// The simulation failed: reduce v and c by the same multiple of m, then
// draw one real digit and append it to c.
IntType j = (v / m) * m; v -= j; c -= j;
v <<= bits; c <<= bits; c += IntType(RandomDigit(r));
}
}
/**
 * \relates UniformInteger
 * Print a UniformInteger. Format is [\e min,\e max] unless the entropy is
 * zero, in which case it is the single value \e min.
 *
 * No random digits are drawn; the range is printed as it currently stands.
 *
 * @param[in,out] os the output stream.
 * @param[in] u the partially sampled integer to print.
 * @return \e os.
 *
 * This needs the full std::ostream definition from <ostream>.
 **********************************************************************/
template<typename IntType, int bits>
std::ostream& operator<<(std::ostream& os,
                         const UniformInteger<IntType, bits>& u) {
  // Zero entropy means Min() == Max(): the value is fully determined.
  if (u.Entropy() == 0)
    return os << u.Min();
  return os << "[" << u.Min() << "," << u.Max() << "]";
}
} // namespace RandomLib
#endif // RANDOMLIB_UNIFORMINTEGER_HPP