// Mirror of https://github.com/VCMP-SqMod/SqMod.git
// (synced 2025-01-19 03:57:14 +01:00); 145 lines, 4.9 KiB, C++.
/**
 * \file MPFRNormal.hpp
 * \brief Header for MPFRNormal
 *
 * Sampling exactly from the normal distribution for MPFR.
 *
 * Copyright (c) Charles Karney (2012) <charles@karney.com> and licensed under
 * the MIT/X11 License.  For more information, see
 * http://randomlib.sourceforge.net/
 **********************************************************************/
|
|
|
|
#if !defined(RANDOMLIB_MPFRNORMAL_HPP)
|
|
#define RANDOMLIB_MPFRNORMAL_HPP 1
|
|
|
|
#include <algorithm> // for max/min
|
|
#include <RandomLib/MPFRRandom.hpp>
|
|
|
|
#if HAVE_MPFR || defined(DOXYGEN)
|
|
|
|
namespace RandomLib {
|
|
|
|
/**
|
|
* \brief The normal distribution for MPFR.
|
|
*
|
|
* This is a transcription of ExactNormal (version 1.3) for use with MPFR.
|
|
*
|
|
* This class uses mutable private objects. So a single MPFRNormal object
|
|
* cannot safely be used by multiple threads. In a multi-processing
|
|
* environment, each thread should use a thread-specific MPFRNormal object.
|
|
*
|
|
* @tparam bits the number of bits in each digit.
|
|
**********************************************************************/
|
|
template<int bits = 32> class MPFRNormal {
|
|
public:
|
|
|
|
/**
|
|
* Initialize the MPFRNormal object.
|
|
**********************************************************************/
|
|
MPFRNormal() { mpz_init(_tt); }
|
|
/**
|
|
* Destroy the MPFRNormal object.
|
|
**********************************************************************/
|
|
~MPFRNormal() { mpz_clear(_tt); }
|
|
/**
|
|
* Sample from the normal distribution with mean 0 and variance 1 returning
|
|
* a MPFRRandom.
|
|
*
|
|
* @param[out] t the MPFRRandom result.
|
|
* @param[in,out] r a GMP random generator.
|
|
**********************************************************************/
|
|
void operator()(MPFRRandom<bits>& t,gmp_randstate_t r) const
|
|
{ Compute(r); return _x.swap(t); }
|
|
/**
|
|
* Sample from the normal distribution with mean 0 and variance 1.
|
|
*
|
|
* @param[out] val the sample from the normal distribution
|
|
* @param[in,out] r a GMP random generator.
|
|
* @param[in] round the rounding direction.
|
|
* @return the MPFR ternary result (±1 if val is larger/smaller than
|
|
* the exact sample).
|
|
**********************************************************************/
|
|
int operator()(mpfr_t val, gmp_randstate_t r, mpfr_rnd_t round) const
|
|
{ Compute(r); return _x(val, r, round); }
|
|
private:
|
|
// Disable copy constructor and assignment operator
|
|
MPFRNormal(const MPFRNormal&);
|
|
MPFRNormal& operator=(const MPFRNormal&);
|
|
// True with prob exp(-1/2)
|
|
int ExpProbH(gmp_randstate_t r) const {
|
|
_p.Init(); if (_p.TestHighBit(r)) return 1;
|
|
// von Neumann rejection
|
|
while (true) {
|
|
_q.Init(); if (!_q.LessThan(r, _p)) return 0;
|
|
_p.Init(); if (!_p.LessThan(r, _q)) return 1;
|
|
}
|
|
}
|
|
// True with prob exp(-n/2)
|
|
int ExpProb(gmp_randstate_t r, unsigned n) const {
|
|
while (n--) { if (!ExpProbH(r)) return 0; }
|
|
return 1;
|
|
}
|
|
// n with prob (1-exp(-1/2)) * exp(-n/2)
|
|
unsigned ExpProbN(gmp_randstate_t r) const {
|
|
unsigned n = 0;
|
|
while (ExpProbH(r)) ++n;
|
|
return n;
|
|
}
|
|
// Return:
|
|
// 1 with prob 2k/(2k + 2)
|
|
// 0 with prob 1/(2k + 2)
|
|
// -1 with prob 1/(2k + 2)
|
|
int Choose(gmp_randstate_t r, int k) const {
|
|
const int b = 15; // To avoid integer overflow on multiplication
|
|
const int m = 2 * k + 2;
|
|
int n1 = m - 2, n2 = m - 1;
|
|
while (true) {
|
|
mpz_urandomb(_tt, r, b);
|
|
int d = int( mpz_get_ui(_tt) ) * m;
|
|
n1 = (std::max)((n1 << b) - d, 0);
|
|
if (n1 >= m) return 1;
|
|
n2 = (std::min)((n2 << b) - d, m);
|
|
if (n2 <= 0) return -1;
|
|
if (n1 == 0 && n2 == m) return 0;
|
|
}
|
|
}
|
|
void Compute(gmp_randstate_t r) const {
|
|
while (true) {
|
|
unsigned k = ExpProbN(r); // the integer part of the result.
|
|
if (ExpProb(r, (k - 1) * k)) {
|
|
_x.Init();
|
|
unsigned s = 1;
|
|
for (unsigned j = 0; j <= k; ++j) { // execute k + 1 times
|
|
bool first;
|
|
for (s = 1, first = true; ; s ^= 1, first = false) {
|
|
if (k == 0 && _x.Boolean(r)) break;
|
|
_q.Init(); if (!_q.LessThan(r, first ? _x : _p)) break;
|
|
int y = k == 0 ? 0 : Choose(r, k);
|
|
if (y < 0)
|
|
break;
|
|
else if (y == 0) {
|
|
_p.Init(); if (!_p.LessThan(r, _x)) break;
|
|
}
|
|
_p.swap(_q); // a fast way of doing p = q
|
|
}
|
|
if (s == 0) break;
|
|
}
|
|
if (s != 0) {
|
|
_x.AddInteger(k);
|
|
if (_x.Boolean(r)) _x.Negate();
|
|
return;
|
|
}
|
|
}
|
|
}
|
|
}
|
|
mutable mpz_t _tt; // A temporary
|
|
mutable MPFRRandom<bits> _x;
|
|
mutable MPFRRandom<bits> _p;
|
|
mutable MPFRRandom<bits> _q;
|
|
};
|
|
|
|
} // namespace RandomLib
|
|
|
|
#endif // HAVE_MPFR
|
|
#endif // RANDOMLIB_MPFRNORMAL_HPP
|