1
0
mirror of https://github.com/VCMP-SqMod/SqMod.git synced 2025-01-19 03:57:14 +01:00
SqMod/include/RandomLib/MPFRNormal.hpp

145 lines
4.9 KiB
C++

/**
* \file MPFRNormal.hpp
* \brief Header for MPFRNormal
*
* Sampling exactly from the normal distribution for MPFR.
*
* Copyright (c) Charles Karney (2012) <charles@karney.com> and licensed under
* the MIT/X11 License. For more information, see
* http://randomlib.sourceforge.net/
**********************************************************************/
#if !defined(RANDOMLIB_MPFRNORMAL_HPP)
#define RANDOMLIB_MPFRNORMAL_HPP 1
#include <algorithm> // for max/min
#include <RandomLib/MPFRRandom.hpp>
#if HAVE_MPFR || defined(DOXYGEN)
namespace RandomLib {
/**
* \brief The normal distribution for MPFR.
*
* This is a transcription of ExactNormal (version 1.3) for use with MPFR.
*
* This class uses mutable private objects, so a single MPFRNormal object
* cannot safely be shared by multiple threads. In a multi-threaded
* environment, each thread should use its own MPFRNormal object.
*
* @tparam bits the number of bits in each digit.
**********************************************************************/
template<int bits = 32> class MPFRNormal {
public:
  /**
   * Construct the sampler; this initializes the internal GMP scratch
   * integer.
   **********************************************************************/
  MPFRNormal() { mpz_init(_tt); }
  /**
   * Destroy the sampler, releasing the internal GMP scratch integer.
   **********************************************************************/
  ~MPFRNormal() { mpz_clear(_tt); }
  /**
   * Draw a sample from the normal distribution with mean 0 and variance 1
   * and hand it back as an MPFRRandom.
   *
   * The freshly computed sample is exchanged with \e t, so the previous
   * contents of \e t end up in this object's internal work value.
   *
   * @param[out] t the MPFRRandom result.
   * @param[in,out] r a GMP random generator.
   **********************************************************************/
  void operator()(MPFRRandom<bits>& t, gmp_randstate_t r) const {
    Compute(r);
    _x.swap(t);
  }
  /**
   * Draw a sample from the normal distribution with mean 0 and variance 1
   * and round it into an MPFR number.
   *
   * @param[out] val the sample from the normal distribution
   * @param[in,out] r a GMP random generator.
   * @param[in] round the rounding direction.
   * @return the MPFR ternary result (&plusmn;1 if val is larger/smaller than
   *   the exact sample).
   **********************************************************************/
  int operator()(mpfr_t val, gmp_randstate_t r, mpfr_rnd_t round) const {
    Compute(r);
    return _x(val, r, round);
  }

private:
  // Copying is disabled: both operations are declared private and are
  // intentionally left without a definition.
  MPFRNormal(const MPFRNormal&);
  MPFRNormal& operator=(const MPFRNormal&);

  // Returns 1 (true) with probability exp(-1/2), otherwise 0.
  int ExpProbH(gmp_randstate_t r) const {
    _p.Init();
    if (_p.TestHighBit(r))
      return 1;
    // von Neumann rejection: alternately draw _q below _p and _p below _q;
    // which of the two comparisons fails first decides the outcome.
    for (;;) {
      _q.Init();
      if (!_q.LessThan(r, _p))
        return 0;
      _p.Init();
      if (!_p.LessThan(r, _q))
        return 1;
    }
  }

  // Returns 1 (true) with probability exp(-n/2): all n independent
  // ExpProbH trials must succeed (trivially true for n == 0).
  int ExpProb(gmp_randstate_t r, unsigned n) const {
    for (unsigned trial = 0; trial < n; ++trial) {
      if (!ExpProbH(r))
        return 0;
    }
    return 1;
  }

  // Returns n with probability (1-exp(-1/2)) * exp(-n/2), i.e. the number
  // of consecutive successful ExpProbH trials before the first failure.
  unsigned ExpProbN(gmp_randstate_t r) const {
    unsigned successes = 0;
    for (; ExpProbH(r); ++successes) {
    }
    return successes;
  }

  // Three-way random choice.  Returns:
  //    1 with prob 2k/(2k + 2)
  //    0 with prob  1/(2k + 2)
  //   -1 with prob  1/(2k + 2)
  // Each pass draws `shift` fresh random bits and rescales the two running
  // numerators (clamped to [0, denom]) until the outcome is decided.
  int Choose(gmp_randstate_t r, int k) const {
    const int shift = 15;        // Small enough to avoid integer overflow
                                 // on the multiplication below
    const int denom = 2 * k + 2;
    int lower = denom - 2;
    int upper = denom - 1;
    for (;;) {
      mpz_urandomb(_tt, r, shift);
      const int offset = static_cast<int>(mpz_get_ui(_tt)) * denom;
      lower = (std::max)((lower << shift) - offset, 0);
      if (lower >= denom)
        return 1;
      upper = (std::min)((upper << shift) - offset, denom);
      if (upper <= 0)
        return -1;
      if (lower == 0 && upper == denom)
        return 0;
    }
  }

  // Generate one normal deviate into _x.  Candidates are produced and
  // rejected in a loop; the function returns only once one is accepted.
  void Compute(gmp_randstate_t r) const {
    for (;;) {
      const unsigned k = ExpProbN(r); // the integer part of the result.
      // Unsigned arithmetic makes (k - 1) * k equal to 0 when k == 0.
      if (!ExpProb(r, (k - 1) * k))
        continue;                     // candidate integer part rejected

      _x.Init();                      // start a fresh fractional part
      unsigned parity = 1;
      // The trial below is executed up to k + 1 times; it stops early as
      // soon as one trial finishes with parity == 0.
      for (unsigned pass = 0; pass <= k && parity != 0; ++pass) {
        parity = 1;
        bool first = true;
        for (;;) {
          if (k == 0 && _x.Boolean(r))
            break;
          _q.Init();
          // The first comparison is against _x, later ones against the
          // previously accepted value held in _p.
          if (!_q.LessThan(r, first ? _x : _p))
            break;
          const int y = (k == 0) ? 0 : Choose(r, k);
          if (y < 0)
            break;
          if (y == 0) {
            _p.Init();
            if (!_p.LessThan(r, _x))
              break;
          }
          _p.swap(_q);                // a fast way of doing p = q
          parity ^= 1;
          first = false;
        }
      }

      if (parity == 0)
        continue;                     // fractional part rejected; restart

      _x.AddInteger(k);
      if (_x.Boolean(r))
        _x.Negate();                  // attach a random sign
      return;
    }
  }

  mutable mpz_t _tt;                  // A temporary
  mutable MPFRRandom<bits> _x;        // Work value that receives the sample
  mutable MPFRRandom<bits> _p;        // Scratch value used by the trials
  mutable MPFRRandom<bits> _q;        // Scratch value used by the trials
};
} // namespace RandomLib
#endif // HAVE_MPFR
#endif // RANDOMLIB_MPFRNORMAL_HPP