00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00014
00015
00016 #ifndef __defined_libdai_prob_h
00017 #define __defined_libdai_prob_h
00018
00019
#include <algorithm>
#include <cmath>
#include <functional>
#include <iterator>
#include <numeric>
#include <ostream>
#include <vector>
#include <dai/util.h>
#include <dai/exceptions.h>
00028
00029
00030 namespace dai {
00031
00032
/// Function object that returns its argument unchanged (identity).
template<typename T> struct fo_id : public std::unary_function<T, T> {
    /// Returns a copy of \a x.
    T operator()( const T &x ) const { return x; }
};
00040
00041
/// Function object that takes the absolute value.
template<typename T> struct fo_abs : public std::unary_function<T, T> {
    /// Returns \c -x when \a x compares less than zero, and \a x otherwise.
    T operator()( const T &x ) const {
        return ( x < (T)0 ) ? -x : x;
    }
};
00052
00053
/// Function object that takes the exponent.
template<typename T> struct fo_exp : public std::unary_function<T, T> {
    /// Returns \c exp(x); the call is unqualified, so overload resolution picks the \c exp in scope for \a T.
    T operator()( const T &x ) const { return exp( x ); }
};
00061
00062
/// Function object that takes the logarithm.
template<typename T> struct fo_log : public std::unary_function<T, T> {
    /// Returns \c log(x) without any special handling of zero (see fo_log0 for that).
    T operator()( const T &x ) const { return log( x ); }
};
00070
00071
/// Function object that takes the logarithm, with the convention log(0) = 0.
template<typename T> struct fo_log0 : public std::unary_function<T, T> {
    /// Returns 0 when \a x converts to \c false (i.e. is zero), and \c log(x) otherwise.
    T operator()( const T &x ) const {
        if( !x )
            return 0;
        return log( x );
    }
};
00082
00083
/// Function object that takes the reciprocal.
template<typename T> struct fo_inv : public std::unary_function<T, T> {
    /// Returns \c 1/x without any special handling of zero (see fo_inv0 for that).
    T operator()( const T &x ) const { return 1 / x; }
};
00091
00092
/// Function object that takes the reciprocal, with the convention 1/0 = 0.
template<typename T> struct fo_inv0 : public std::unary_function<T, T> {
    /// Returns 0 when \a x converts to \c false (i.e. is zero), and \c 1/x otherwise.
    T operator()( const T &x ) const {
        if( !x )
            return 0;
        return 1 / x;
    }
};
00103
00104
00106 template<typename T> struct fo_plog0p : public std::unary_function<T, T> {
00108 T operator()( const T &p ) const {
00109 return p * dai::log0(p);
00110 }
00111 };
00112
00113
/// Function object similar to std::divides, but with the convention x/0 = 0.
template<typename T> struct fo_divides0 : public std::binary_function<T, T, T> {
    /// Returns \c x/y, or zero when the divisor \a y equals zero.
    T operator()( const T &x, const T &y ) const {
        return ( y == (T)0 ) ? (T)0 : x / y;
    }
};
00124
00125
/// Function object for one term of the Kullback-Leibler divergence.
template<typename T> struct fo_KL : public std::binary_function<T, T, T> {
    /// Returns \c p*(log(p)-log(q)), or zero when \a p equals zero.
    T operator()( const T &p, const T &q ) const {
        // A zero p contributes nothing, whatever q is; this also avoids log(0).
        return ( p == (T)0 ) ? (T)0 : p * (log(p) - log(q));
    }
};
00136
00137
/// Function object for one term of the (squared) Hellinger distance.
template<typename T> struct fo_Hellinger : public std::binary_function<T, T, T> {
    /// Returns \c (sqrt(p)-sqrt(q))^2.
    T operator()( const T &p, const T &q ) const {
        T rootGap = sqrt(p) - sqrt(q);
        return rootGap * rootGap;
    }
};
00146
00147
/// Function object that raises its first argument to the power of the second.
template<typename T> struct fo_pow : public std::binary_function<T, T, T> {
    /// Returns \a x itself when \a y equals 1 (skipping the call), and \c std::pow(x,y) otherwise.
    T operator()( const T &x, const T &y ) const {
        if( y == 1 )
            return x;
        return std::pow( x, y );
    }
};
00158
00159
/// Function object that returns the larger of two values.
template<typename T> struct fo_max : public std::binary_function<T, T, T> {
    /// Returns \a x if \c x > y, and \a y otherwise (so \a y is returned on ties).
    T operator()( const T &x, const T &y ) const {
        if( x > y )
            return x;
        return y;
    }
};
00167
00168
/// Function object that returns the smaller of two values.
template<typename T> struct fo_min : public std::binary_function<T, T, T> {
    /// Returns \a y if \c x > y, and \a x otherwise (so \a x is returned on ties).
    T operator()( const T &x, const T &y ) const {
        if( x > y )
            return y;
        return x;
    }
};
00176
00177
00179 template<typename T> struct fo_absdiff : public std::binary_function<T, T, T> {
00181 T operator()( const T &x, const T &y ) const {
00182 return dai::abs( x - y );
00183 }
00184 };
00185
00186
00188
00196 template <typename T> class TProb {
00197 private:
00199 std::vector<T> _p;
00200
00201 public:
00203
00207 typedef enum { NORMPROB, NORMLINF } NormType;
00209
00216 typedef enum { DISTL1, DISTLINF, DISTTV, DISTKL, DISTHEL } DistType;
00217
00219
00220
00221 TProb() : _p() {}
00222
00224 explicit TProb( size_t n ) : _p(std::vector<T>(n, (T)1 / n)) {}
00225
00227 explicit TProb( size_t n, T p ) : _p(n, p) {}
00228
00230
00235 template <typename TIterator>
00236 TProb( TIterator begin, TIterator end, size_t sizeHint=0 ) : _p() {
00237 _p.reserve( sizeHint );
00238 _p.insert( _p.begin(), begin, end );
00239 }
00240
00242
00245 template <typename S>
00246 TProb( const std::vector<S> &v ) : _p() {
00247 _p.reserve( v.size() );
00248 _p.insert( _p.begin(), v.begin(), v.end() );
00249 }
00251
00253 typedef typename std::vector<T>::const_iterator const_iterator;
00255 typedef typename std::vector<T>::iterator iterator;
00257 typedef typename std::vector<T>::const_reverse_iterator const_reverse_iterator;
00259 typedef typename std::vector<T>::reverse_iterator reverse_iterator;
00260
00262
00263
00264 iterator begin() { return _p.begin(); }
00266 const_iterator begin() const { return _p.begin(); }
00267
00269 iterator end() { return _p.end(); }
00271 const_iterator end() const { return _p.end(); }
00272
00274 reverse_iterator rbegin() { return _p.rbegin(); }
00276 const_reverse_iterator rbegin() const { return _p.rbegin(); }
00277
00279 reverse_iterator rend() { return _p.rend(); }
00281 const_reverse_iterator rend() const { return _p.rend(); }
00283
00285
00286
00287 const std::vector<T> & p() const { return _p; }
00288
00290 std::vector<T> & p() { return _p; }
00291
00293 T operator[]( size_t i ) const {
00294 #ifdef DAI_DEBUG
00295 return _p.at(i);
00296 #else
00297 return _p[i];
00298 #endif
00299 }
00300
00302 T& operator[]( size_t i ) { return _p[i]; }
00303
00305 size_t size() const { return _p.size(); }
00306
00308 template<typename binOp, typename unOp> T accumulate( T init, binOp op1, unOp op2 ) const {
00309 T t = init;
00310 for( const_iterator it = begin(); it != end(); it++ )
00311 t = op1( t, op2(*it) );
00312 return t;
00313 }
00314
00316 T entropy() const { return -accumulate( (T)0, std::plus<T>(), fo_plog0p<T>() ); }
00317
00319 T max() const { return accumulate( (T)(-INFINITY), fo_max<T>(), fo_id<T>() ); }
00320
00322 T min() const { return accumulate( (T)INFINITY, fo_min<T>(), fo_id<T>() ); }
00323
00325 T sum() const { return accumulate( (T)0, std::plus<T>(), fo_id<T>() ); }
00326
00328 T sumAbs() const { return accumulate( (T)0, std::plus<T>(), fo_abs<T>() ); }
00329
00331 T maxAbs() const { return accumulate( (T)0, fo_max<T>(), fo_abs<T>() ); }
00332
00334 bool hasNaNs() const {
00335 bool foundnan = false;
00336 for( typename std::vector<T>::const_iterator x = _p.begin(); x != _p.end(); x++ )
00337 if( isnan( *x ) ) {
00338 foundnan = true;
00339 break;
00340 }
00341 return foundnan;
00342 }
00343
00345 bool hasNegatives() const {
00346 return (std::find_if( _p.begin(), _p.end(), std::bind2nd( std::less<T>(), (T)0 ) ) != _p.end());
00347 }
00348
00350 std::pair<size_t,T> argmax() const {
00351 T max = _p[0];
00352 size_t arg = 0;
00353 for( size_t i = 1; i < size(); i++ ) {
00354 if( _p[i] > max ) {
00355 max = _p[i];
00356 arg = i;
00357 }
00358 }
00359 return std::make_pair(arg,max);
00360 }
00361
00363 size_t draw() {
00364 Real x = rnd_uniform() * sum();
00365 T s = 0;
00366 for( size_t i = 0; i < size(); i++ ) {
00367 s += _p[i];
00368 if( s > x )
00369 return i;
00370 }
00371 return( size() - 1 );
00372 }
00373
00375
00377 bool operator<= (const TProb<T> & q) const {
00378 DAI_DEBASSERT( size() == q.size() );
00379 return lexicographical_compare( begin(), end(), q.begin(), q.end() );
00380 }
00382
00384
00385
00386 template<typename unaryOp> TProb<T> pwUnaryTr( unaryOp op ) const {
00387 TProb<T> r;
00388 r._p.reserve( size() );
00389 std::transform( _p.begin(), _p.end(), back_inserter( r._p ), op );
00390 return r;
00391 }
00392
00394 TProb<T> operator- () const { return pwUnaryTr( std::negate<T>() ); }
00395
00397 TProb<T> abs() const { return pwUnaryTr( fo_abs<T>() ); }
00398
00400 TProb<T> exp() const { return pwUnaryTr( fo_exp<T>() ); }
00401
00403
00405 TProb<T> log(bool zero=false) const {
00406 if( zero )
00407 return pwUnaryTr( fo_log0<T>() );
00408 else
00409 return pwUnaryTr( fo_log<T>() );
00410 }
00411
00413
00415 TProb<T> inverse(bool zero=true) const {
00416 if( zero )
00417 return pwUnaryTr( fo_inv0<T>() );
00418 else
00419 return pwUnaryTr( fo_inv<T>() );
00420 }
00421
00423
00425 TProb<T> normalized( NormType norm = NORMPROB ) const {
00426 T Z = 0;
00427 if( norm == NORMPROB )
00428 Z = sum();
00429 else if( norm == NORMLINF )
00430 Z = maxAbs();
00431 if( Z == (T)0 ) {
00432 DAI_THROW(NOT_NORMALIZABLE);
00433 return *this;
00434 } else
00435 return pwUnaryTr( std::bind2nd( std::divides<T>(), Z ) );
00436 }
00438
00440
00441
00442 template<typename unaryOp> TProb<T>& pwUnaryOp( unaryOp op ) {
00443 std::transform( _p.begin(), _p.end(), _p.begin(), op );
00444 return *this;
00445 }
00446
00448 TProb<T>& randomize() {
00449 std::generate( _p.begin(), _p.end(), rnd_uniform );
00450 return *this;
00451 }
00452
00454 TProb<T>& setUniform () {
00455 fill( (T)1 / size() );
00456 return *this;
00457 }
00458
00460 const TProb<T>& takeAbs() { return pwUnaryOp( fo_abs<T>() ); }
00461
00463 const TProb<T>& takeExp() { return pwUnaryOp( fo_exp<T>() ); }
00464
00466
00468 const TProb<T>& takeLog(bool zero=false) {
00469 if( zero ) {
00470 return pwUnaryOp( fo_log0<T>() );
00471 } else
00472 return pwUnaryOp( fo_log<T>() );
00473 }
00474
00476
00478 T normalize( NormType norm=NORMPROB ) {
00479 T Z = 0;
00480 if( norm == NORMPROB )
00481 Z = sum();
00482 else if( norm == NORMLINF )
00483 Z = maxAbs();
00484 if( Z == (T)0 )
00485 DAI_THROW(NOT_NORMALIZABLE);
00486 else
00487 *this /= Z;
00488 return Z;
00489 }
00491
00493
00494
00495 TProb<T> & fill(T x) {
00496 std::fill( _p.begin(), _p.end(), x );
00497 return *this;
00498 }
00499
00501 TProb<T>& operator+= (T x) {
00502 if( x != 0 )
00503 return pwUnaryOp( std::bind2nd( std::plus<T>(), x ) );
00504 else
00505 return *this;
00506 }
00507
00509 TProb<T>& operator-= (T x) {
00510 if( x != 0 )
00511 return pwUnaryOp( std::bind2nd( std::minus<T>(), x ) );
00512 else
00513 return *this;
00514 }
00515
00517 TProb<T>& operator*= (T x) {
00518 if( x != 1 )
00519 return pwUnaryOp( std::bind2nd( std::multiplies<T>(), x ) );
00520 else
00521 return *this;
00522 }
00523
00525 TProb<T>& operator/= (T x) {
00526 DAI_DEBASSERT( x != 0 );
00527 if( x != 1 )
00528 return pwUnaryOp( std::bind2nd( std::divides<T>(), x ) );
00529 else
00530 return *this;
00531 }
00532
00534 TProb<T>& operator^= (T x) {
00535 if( x != (T)1 )
00536 return pwUnaryOp( std::bind2nd( fo_pow<T>(), x) );
00537 else
00538 return *this;
00539 }
00541
00543
00544
00545 TProb<T> operator+ (T x) const { return pwUnaryTr( std::bind2nd( std::plus<T>(), x ) ); }
00546
00548 TProb<T> operator- (T x) const { return pwUnaryTr( std::bind2nd( std::minus<T>(), x ) ); }
00549
00551 TProb<T> operator* (T x) const { return pwUnaryTr( std::bind2nd( std::multiplies<T>(), x ) ); }
00552
00554 TProb<T> operator/ (T x) const { return pwUnaryTr( std::bind2nd( fo_divides0<T>(), x ) ); }
00555
00557 TProb<T> operator^ (T x) const { return pwUnaryTr( std::bind2nd( fo_pow<T>(), x ) ); }
00559
00561
00562
00563
00567 template<typename binaryOp> TProb<T>& pwBinaryOp( const TProb<T> &q, binaryOp op ) {
00568 DAI_DEBASSERT( size() == q.size() );
00569 std::transform( _p.begin(), _p.end(), q._p.begin(), _p.begin(), op );
00570 return *this;
00571 }
00572
00574
00576 TProb<T>& operator+= (const TProb<T> & q) { return pwBinaryOp( q, std::plus<T>() ); }
00577
00579
00581 TProb<T>& operator-= (const TProb<T> & q) { return pwBinaryOp( q, std::minus<T>() ); }
00582
00584
00586 TProb<T>& operator*= (const TProb<T> & q) { return pwBinaryOp( q, std::multiplies<T>() ); }
00587
00589
00592 TProb<T>& operator/= (const TProb<T> & q) { return pwBinaryOp( q, fo_divides0<T>() ); }
00593
00595
00598 TProb<T>& divide (const TProb<T> & q) { return pwBinaryOp( q, std::divides<T>() ); }
00599
00601
00603 TProb<T>& operator^= (const TProb<T> & q) { return pwBinaryOp( q, fo_pow<T>() ); }
00605
00607
00608
00609
00613 template<typename binaryOp> TProb<T> pwBinaryTr( const TProb<T> &q, binaryOp op ) const {
00614 DAI_DEBASSERT( size() == q.size() );
00615 TProb<T> r;
00616 r._p.reserve( size() );
00617 std::transform( _p.begin(), _p.end(), q._p.begin(), back_inserter( r._p ), op );
00618 return r;
00619 }
00620
00622
00624 TProb<T> operator+ ( const TProb<T>& q ) const { return pwBinaryTr( q, std::plus<T>() ); }
00625
00627
00629 TProb<T> operator- ( const TProb<T>& q ) const { return pwBinaryTr( q, std::minus<T>() ); }
00630
00632
00634 TProb<T> operator* ( const TProb<T> &q ) const { return pwBinaryTr( q, std::multiplies<T>() ); }
00635
00637
00640 TProb<T> operator/ ( const TProb<T> &q ) const { return pwBinaryTr( q, fo_divides0<T>() ); }
00641
00643
00646 TProb<T> divided_by( const TProb<T> &q ) const { return pwBinaryTr( q, std::divides<T>() ); }
00647
00649
00651 TProb<T> operator^ ( const TProb<T> &q ) const { return pwBinaryTr( q, fo_pow<T>() ); }
00653
00655
00657 template<typename binOp1, typename binOp2> T innerProduct( const TProb<T> &q, T init, binOp1 binaryOp1, binOp2 binaryOp2 ) const {
00658 DAI_DEBASSERT( size() == q.size() );
00659 return std::inner_product( begin(), end(), q.begin(), init, binaryOp1, binaryOp2 );
00660 }
00661 };
00662
00663
00665
00668 template<typename T> T dist( const TProb<T> &p, const TProb<T> &q, typename TProb<T>::DistType dt ) {
00669 switch( dt ) {
00670 case TProb<T>::DISTL1:
00671 return p.innerProduct( q, (T)0, std::plus<T>(), fo_absdiff<T>() );
00672 case TProb<T>::DISTLINF:
00673 return p.innerProduct( q, (T)0, fo_max<T>(), fo_absdiff<T>() );
00674 case TProb<T>::DISTTV:
00675 return p.innerProduct( q, (T)0, std::plus<T>(), fo_absdiff<T>() ) / 2;
00676 case TProb<T>::DISTKL:
00677 return p.innerProduct( q, (T)0, std::plus<T>(), fo_KL<T>() );
00678 case TProb<T>::DISTHEL:
00679 return p.innerProduct( q, (T)0, std::plus<T>(), fo_Hellinger<T>() ) / 2;
00680 default:
00681 DAI_THROW(UNKNOWN_ENUM_VALUE);
00682 return INFINITY;
00683 }
00684 }
00685
00686
00688
00690 template<typename T> std::ostream& operator<< (std::ostream& os, const TProb<T>& p) {
00691 os << "[";
00692 std::copy( p.p().begin(), p.p().end(), std::ostream_iterator<T>(os, " ") );
00693 os << "]";
00694 return os;
00695 }
00696
00697
00699
00702 template<typename T> TProb<T> min( const TProb<T> &a, const TProb<T> &b ) {
00703 return a.pwBinaryTr( b, fo_min<T>() );
00704 }
00705
00706
00708
00711 template<typename T> TProb<T> max( const TProb<T> &a, const TProb<T> &b ) {
00712 return a.pwBinaryTr( b, fo_max<T>() );
00713 }
00714
00715
/// Convenience alias: a TProb whose entries have type Real (Real is declared outside this file).
00717 typedef TProb<Real> Prob;
00718
00719
00720 }
00721
00722
00723 #endif