1
0
mirror of https://github.com/VCMP-SqMod/SqMod.git synced 2025-02-12 07:47:12 +01:00
SqMod/include/RandomLib/MPFRRandom.hpp

384 lines
15 KiB
C++
Raw Normal View History

/**
* \file MPFRRandom.hpp
* \brief Header for MPFRRandom
*
* Utility class for MPFRUniform, MPFRExponential, and MPFRNormal.
*
* Copyright (c) Charles Karney (2012) <charles@karney.com> and licensed under
* the MIT/X11 License. For more information, see
* http://randomlib.sourceforge.net/
**********************************************************************/
#if !defined(RANDOMLIB_MPFRRANDOM_HPP)
#define RANDOMLIB_MPFRRANDOM_HPP 1
#include <algorithm> // for swap
#include <mpfr.h>
#define HAVE_MPFR (MPFR_VERSION_MAJOR >= 3)
#if HAVE_MPFR || defined(DOXYGEN)
/**
* A compile-time assert. Use C++11 static_assert, if available.
**********************************************************************/
#if !defined(STATIC_ASSERT)
# if defined(__GXX_EXPERIMENTAL_CXX0X__)
# define STATIC_ASSERT static_assert
# elif defined(_MSC_VER) && _MSC_VER >= 1600
# define STATIC_ASSERT static_assert
# else
# define STATIC_ASSERT(cond,reason) \
{ enum{ STATIC_ASSERT_ENUM = 1/int(cond) }; }
# endif
#endif
namespace RandomLib {
/**
* \brief Handling random numbers in MPFR.
*
* This class provides roughly the same capabilities as RandomNumber. The
* fraction is represented by a mpz integer \e f and an exponent \e e. We
* have \e e &ge; 0 and 0 &le; \e f < <i>b</i><sup><i>e</i></sup>, and \e b =
* 2<sup><i>bits</i></sup>. This represents the number \e x = \e f
* <i>b</i><sup>&minus;<i>e</i></sup>, with x in [0, 1).
*
* @tparam bits the number of bits in each digit.
*
* \e bits must divide GMP_LIMB_BITS. The default value \e bits = 32 yields
* portable results on all MPFR platforms.
**********************************************************************/
template<int bits = 32> class MPFRRandom {
private:
static const int limb_ = GMP_LIMB_BITS; // How many bits in a limb
static const int loglimb_ = (limb_ == 32 ? 5 :
(limb_ == 64 ? 6 :
(limb_ == 128 ? 7 : -1)));
static const int logbits_ = (bits == 1 ? 0 :
(bits == 2 ? 1 :
(bits == 4 ? 2 :
(bits == 8 ? 3 :
(bits == 16 ? 4 :
(bits == 32 ? 5 :
(bits == 64 ? 6 :
(bits == 128 ? 7 : -1))))))));
static const mp_limb_t mask_ = (bits == limb_ ? ~0UL : // Digit mask
~(~0UL << (bits < limb_ ? bits : 0)));
static const int logw_ = loglimb_ - logbits_; // 2^logw digits per limb
static const unsigned wmask_ = ~(~0U << logw_);
mutable mpz_t _tt; // A temporary
mpz_t _f; // The fraction
mp_size_t _e; // Count of digits
unsigned long _n; // Integer part
int _s; // Sign
void AddDigits(gmp_randstate_t r, long num = 1) { // Add num more digits
if (num <= 0) return;
mpz_mul_2exp(_f, _f, num << logbits_);
mpz_urandomb(_tt, r, num << logbits_);
mpz_add(_f, _f, _tt);
_e += num;
}
// return k'th digit counting k = 0 as most significant
mp_limb_t Digit(gmp_randstate_t r, mp_size_t k) {
ExpandTo(r, k); // Now e > k
k = _e - 1 - k; // Reverse k so k = 0 is least significant
// (k >> logw) is the limb index
// (k & wmask) is the digit position within the limb
return mask_ &
(mpz_getlimbn(_f, k >> logw_) >> ((k & wmask_) << logbits_));
}
// Return index [0..32] of highest bit set. Return 0 if x = 0, 32 is if x
// = ~0. (From Algorithms for programmers by Joerg Arndt.)
static int highest_bit_idx(unsigned long x) throw() {
if (x == 0) return 0;
int r = 1;
// STILL TO DO: handle 64-bit unsigned longs.
if (x & 0xffff0000UL) { x >>= 16; r += 16; }
if (x & 0x0000ff00UL) { x >>= 8; r += 8; }
if (x & 0x000000f0UL) { x >>= 4; r += 4; }
if (x & 0x0000000cUL) { x >>= 2; r += 2; }
if (x & 0x00000002UL) { r += 1; }
return r;
}
public:
/**
* Initialize the MPFRRandom object.
**********************************************************************/
MPFRRandom() : _e(0u), _n(0u), _s(1) {
STATIC_ASSERT(logbits_ >= 0 && loglimb_ >= 0 && logbits_ <= loglimb_,
"MPRFRandom: unsupported value for bits");
mpz_init(_f); mpz_init(_tt);
}
/**
* Initialize the MPFRRandom object from another one.
*
* @param[in] t the MPFRRandom to copy.
**********************************************************************/
MPFRRandom(const MPFRRandom& t) : _e(t._e), _n(t._n), _s(t._s)
{ mpz_init(_f); mpz_set(_f, t._f); mpz_init(_tt); }
/**
* Destroy the MPFRRandom object.
**********************************************************************/
~MPFRRandom() { mpz_clear(_f); mpz_clear(_tt); }
/**
* Assignment operator. (But swapping is typically faster.)
*
* @param[in] t the MPFRRandom to copy.
**********************************************************************/
MPFRRandom& operator=(const MPFRRandom& t) {
_e = t._e;
_n = t._n;
_s = t._s;
mpz_set(_f, t._f); // Don't copy _tt
return *this;
}
/**
* Swap with another MPFRRandom. This is a fast way of doing an
* assignment.
*
* @param[in,out] t the MPFRRandom to swap with.
**********************************************************************/
void swap(MPFRRandom& t) throw() {
if (this != &t) {
std::swap(_e, t._e);
std::swap(_n, t._n);
std::swap(_s, t._s);
mpz_swap(_f, t._f); // Don't swap _tt
}
}
/**
* Reinitialize the MPFRRandom object, setting its value to [0,1].
**********************************************************************/
void Init() { mpz_set_ui(_f, 0u); _e = 0; _n = 0; _s = 1; }
/**
* @return the sign of the MPFRRandom (&plusmn; 1).
**********************************************************************/
int Sign() const throw() { return _s; }
/**
* Change the sign of the MPFRRandom.
**********************************************************************/
void Negate() throw() { _s *= -1; }
/**
* @return the floor of the MPFRRandom
**********************************************************************/
long Floor() const throw() { return _s > 0 ? long(_n) : -1 - long(_n); }
/**
* @return the ceiling of the MPFRRandom
**********************************************************************/
long Ceiling() const throw() { return _s > 0 ? 1 + long(_n) : -long(_n); }
/**
* @return the unsigned integer component of the MPFRRandom.
**********************************************************************/
unsigned long UInteger() const throw() { return _n; }
/**
* @return the number of digits in fraction
**********************************************************************/
unsigned long Size() const throw() { return unsigned(_e); }
/**
* Add integer \e k to the MPRFRandom.
*
* @param[in] k the integer to add.
**********************************************************************/
void AddInteger(long k) {
k += Floor(); // The new floor
int ns = k < 0 ? -1 : 1; // The new sign
if (ns != _s) { // If sign changes, set f = 1 - f
mpz_set_ui(_tt, 1u);
mpz_mul_2exp(_tt, _tt, _e << logbits_);
mpz_sub_ui(_tt, _tt, 1u);
mpz_sub(_f, _tt, _f);
_s = ns;
}
_n = ns > 0 ? k : -(k + 1);
}
/**
* Compare with another MPFRRandom, *this < \e t.
*
* @param[in,out] r a random generator.
* @param[in,out] t a MPFRRandom to compare.
* @return true if *this < \e t.
**********************************************************************/
int LessThan(gmp_randstate_t r, MPFRRandom& t) {
if (this == &t) return false; // same object
if (_s != t._s) return _s < t._s;
if (_n != t._n) return (_s < 0) ^ (_n < t._n);
for (mp_size_t k = 0; ; ++k) {
mp_limb_t x = Digit(r, k);
mp_limb_t y = t.Digit(r, k);
if (x != y) return (_s < 0) ^ (x < y);
}
}
/**
* Set high bit of fraction to 1.
*
* @param[in,out] r a random generator.
**********************************************************************/
void SetHighBit(gmp_randstate_t r) { // Set the msb to 1
ExpandTo(r, 0); // Generate msb if necessary
mpz_setbit(_f, (_e << logbits_) - 1);
}
/**
* Test high bit of fraction.
*
* @param[in,out] r a random generator.
**********************************************************************/
int TestHighBit(gmp_randstate_t r) { // test the msb of f
ExpandTo(r, 0); // Generate msb if necessary
return mpz_tstbit(_f, (_e << logbits_) - 1);
}
/**
* Return the position of the most significant bit in the MPFRRandom.
*
* @param[in,out] r a random generator.
*
* The bit position is numbered such the 1/2 bit is 0, the 1/4 bit is -1,
* etc.
**********************************************************************/
mp_size_t LeadingBit(gmp_randstate_t r) {
if (_n) return highest_bit_idx(_n);
while (true) {
int sgn = mpz_sgn(_f);
if (sgn != 0)
return mp_size_t(mpz_sizeinbase(_f, 2)) - mp_size_t(_e << logbits_);
AddDigits(r);
}
}
/**
* Ensure that the k'th digit of the fraction is computed.
*
* @param[in,out] r a random generator.
* @param[in] k the digit number (0 is the most significant, 1 is the next
* most significant, etc.
**********************************************************************/
void ExpandTo(gmp_randstate_t r, mp_size_t k)
{ if (_e <= k) AddDigits(r, k - _e + 1); }
/**
* Convert to a MPFR number \e without adding more bits.
*
* @param[out] val the value of s * (n + *this).
* @param[in] round the rounding direction.
* @return the MPFR ternary result (&plusmn; if val is larger/smaller than
* the exact sample).
*
* If round is MPFR_RNDN, then the rounded midpoint of the interval
* represented by the MPFRRandom is returned. Otherwise it is the rounded
* lower or upper bound of the interval (whichever is appropriate).
**********************************************************************/
int operator()(mpfr_t val, mpfr_rnd_t round)
{ return operator()(val, NULL, round); }
/**
* Convert to a MPFR number.
*
* @param[out] val the value of s * (n + *this).
* @param[in,out] r a GMP random generator.
* @param[in] round the rounding direction.
* @return the MPFR ternary result (&plusmn; if val is larger/smaller than
* the exact sample).
*
* If \e r is NULL, then no additional random bits are generated and the
* lower bound, midpoint, or upper bound of the MPFRRandom interval is
* returned, depending on the value of \e round.
**********************************************************************/
int operator()(mpfr_t val, gmp_randstate_t r, mpfr_rnd_t round) {
// The value is constructed as a positive quantity, so adjust rounding
// mode to account for this.
switch (round) {
case MPFR_RNDD:
case MPFR_RNDU:
case MPFR_RNDN:
break;
case MPFR_RNDZ:
round = _s < 0 ? MPFR_RNDU : MPFR_RNDD;
break;
case MPFR_RNDA:
round = _s < 0 ? MPFR_RNDD : MPFR_RNDU;
break;
default:
round = MPFR_RNDN; // New rounding modes are variants of N
break;
} // Now round is one of MPFR_RND{D,N,U}
mp_size_t excess;
mpfr_exp_t expt;
if (r == NULL) {
// If r is NULL then all the bits currently generated are considered
// significant. Thus no excess bits need to be squeezed out.
excess = 0;
// And the exponent shift in mpfr_set_z_2exp is just...
expt = -(_e << logbits_);
// However, if rounding to nearest, we need to make room for the
// midpoint bit.
if (round == MPFR_RNDN) {
excess = -1;
--expt;
}
} else { // r is non-NULL
// Generate enough digits, i.e., enough to generate prec significant
// figures for RNDD and RNDU; for RNDN we need to generate an
// additional guard bit.
mp_size_t lead = LeadingBit(r);
mpfr_prec_t prec = mpfr_get_prec (val);
mp_size_t trail = lead - prec; // position one past trailing bit
mp_size_t guard = trail + (round == MPFR_RNDN ? 0 : 1); // guard bit pos
// Generate the bits needed.
if (guard <= 0) ExpandTo(r, (-guard) >> logbits_);
// Unless bits = 1, the generation process will typically have
// generated too many bits. We figure out how many, but leaving room
// for one additional "inexact" bit. The inexact bit is set to 1 in
// order to force MPFR to treat the result as inexact, to break RNDN
// ties, and to get the ternary value set correctly.
//
// expt is the exponent used when forming the number using
// mpfr_set_z_2exp. Without the inexact bit, it's (guard - 1).
// Subtract 1 to account for the inexact bit.
expt = guard - 2;
// The number of excess bits is now the difference between the number
// of bits in the fraction (e << logbits) and -expt. Note that this
// may be -1 (meaning we'll need to shift the number left to
// accommodate the inexact bit).
excess = (_e << logbits_) + expt;
}
mpz_set_ui(_tt, _n); // The integer part
mpz_mul_2exp(_tt, _tt, _e << logbits_); // Shift to allow for fraction
mpz_add(_tt, _tt, _f); // Add fraction
if (excess > 0)
mpz_tdiv_q_2exp(_tt, _tt, excess);
else if (excess < 0)
mpz_mul_2exp(_tt, _tt, -excess);
if (r || round == MPFR_RNDN)
// Set the inexact bit (or compute the midpoint if r is NULL).
mpz_setbit(_tt, 0);
else if (round == MPFR_RNDU)
// If r is NULL, compute the upper bound.
mpz_add_ui(_tt, _tt, 1u);
// Convert to a mpfr number. If r is specified, then there are
// sufficient bits in tt that the result is inexact and that (in the case
// of RNDN) there are no ties.
int flag = mpfr_set_z_2exp(val, _tt, expt, round);
if (_s < 0) {
mpfr_neg (val, val, MPFR_RNDN);
flag = -flag;
}
return flag;
}
/**
* A coin toss. (This should really be a static function. But it uses the
* MPFRRandom temporary variable.)
*
* @param[in,out] r a GMP random generator.
* @return true or false.
**********************************************************************/
int Boolean(gmp_randstate_t r) const {
mpz_urandomb(_tt, r, 1);
return mpz_tstbit(_tt, 0);
}
};
} // namespace RandomLib
#endif // HAVE_MPFR
#endif // RANDOMLIB_MPFRRANDOM_HPP