00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00014
00015
00016 #ifndef __defined_libdai_prob_h
00017 #define __defined_libdai_prob_h
00018
00019
#include <cmath>
#include <vector>
#include <ostream>
#include <algorithm>
#include <iterator>
#include <numeric>
#include <functional>
#include <dai/util.h>
#include <dai/exceptions.h>
00028
00029
00030 namespace dai {
00031
00032
/// Function object that returns its argument unchanged.
/** Provides the nested typedefs of an adaptable unary function directly
 *  instead of deriving from std::unary_function, which was deprecated in
 *  C++11 and removed in C++17.
 */
template<typename T> struct fo_id {
    /// Argument type (adaptable function object interface)
    typedef T argument_type;
    /// Result type (adaptable function object interface)
    typedef T result_type;

    /// Returns \a x
    T operator()( const T &x ) const {
        return x;
    }
};
00040
00041
/// Function object that returns the absolute value of its argument.
/** Provides the adaptable-unary-function typedefs directly instead of
 *  deriving from std::unary_function (removed in C++17).
 */
template<typename T> struct fo_abs {
    /// Argument type (adaptable function object interface)
    typedef T argument_type;
    /// Result type (adaptable function object interface)
    typedef T result_type;

    /// Returns \a -x if \a x is smaller than zero, otherwise \a x
    T operator()( const T &x ) const {
        if( x < (T)0 )
            return -x;
        else
            return x;
    }
};
00052
00053
/// Function object that computes the exponent of its argument.
/** Provides the adaptable-unary-function typedefs directly instead of
 *  deriving from std::unary_function (removed in C++17).
 */
template<typename T> struct fo_exp {
    /// Argument type (adaptable function object interface)
    typedef T argument_type;
    /// Result type (adaptable function object interface)
    typedef T result_type;

    /// Returns exp(\a x)
    T operator()( const T &x ) const {
        // Bring the std overloads (float/double/long double) into scope while
        // still allowing argument-dependent lookup for custom numeric types.
        using std::exp;
        return exp( x );
    }
};
00061
00062
/// Function object that computes the logarithm of its argument.
/** Provides the adaptable-unary-function typedefs directly instead of
 *  deriving from std::unary_function (removed in C++17).
 */
template<typename T> struct fo_log {
    /// Argument type (adaptable function object interface)
    typedef T argument_type;
    /// Result type (adaptable function object interface)
    typedef T result_type;

    /// Returns log(\a x)
    T operator()( const T &x ) const {
        // std overloads plus argument-dependent lookup for custom types
        using std::log;
        return log( x );
    }
};
00070
00071
/// Function object that computes the logarithm of its argument, mapping zero to zero.
/** Provides the adaptable-unary-function typedefs directly instead of
 *  deriving from std::unary_function (removed in C++17).
 */
template<typename T> struct fo_log0 {
    /// Argument type (adaptable function object interface)
    typedef T argument_type;
    /// Result type (adaptable function object interface)
    typedef T result_type;

    /// Returns log(\a x), or 0 if \a x == 0
    T operator()( const T &x ) const {
        if( x == (T)0 )
            return (T)0;
        // std overloads plus argument-dependent lookup for custom types
        using std::log;
        return log( x );
    }
};
00082
00083
/// Function object that takes the inverse of its argument.
/** Provides the adaptable-unary-function typedefs directly instead of
 *  deriving from std::unary_function (removed in C++17).
 */
template<typename T> struct fo_inv {
    /// Argument type (adaptable function object interface)
    typedef T argument_type;
    /// Result type (adaptable function object interface)
    typedef T result_type;

    /// Returns 1 / \a x (no special handling of zero)
    T operator()( const T &x ) const {
        return (T)1 / x;
    }
};
00091
00092
/// Function object that takes the inverse of its argument, mapping zero to zero.
/** Provides the adaptable-unary-function typedefs directly instead of
 *  deriving from std::unary_function (removed in C++17).
 */
template<typename T> struct fo_inv0 {
    /// Argument type (adaptable function object interface)
    typedef T argument_type;
    /// Result type (adaptable function object interface)
    typedef T result_type;

    /// Returns 1 / \a x, or 0 if \a x == 0
    T operator()( const T &x ) const {
        if( x == (T)0 )
            return (T)0;
        return (T)1 / x;
    }
};
00103
00104
00106 template<typename T> struct fo_plog0p : public std::unary_function<T, T> {
00108 T operator()( const T &p ) const {
00109 return p * dai::log0(p);
00110 }
00111 };
00112
00113
/// Function object similar to std::divides<T>, but returns 0 if the divisor is zero.
/** Provides the adaptable-binary-function typedefs directly instead of
 *  deriving from std::binary_function (removed in C++17). The typedefs are
 *  kept because callers bind this functor with std::bind2nd.
 */
template<typename T> struct fo_divides0 {
    /// First argument type (adaptable function object interface)
    typedef T first_argument_type;
    /// Second argument type (adaptable function object interface)
    typedef T second_argument_type;
    /// Result type (adaptable function object interface)
    typedef T result_type;

    /// Returns \a x / \a y, or 0 if \a y == 0
    T operator()( const T &x, const T &y ) const {
        if( y == (T)0 )
            return (T)0;
        else
            return x / y;
    }
};
00124
00125
/// Function object useful for calculating the KL distance.
/** Computes the pointwise term p * (log p - log q), defined as 0 for p == 0.
 *  Provides the adaptable-binary-function typedefs directly instead of
 *  deriving from std::binary_function (removed in C++17).
 */
template<typename T> struct fo_KL {
    /// First argument type (adaptable function object interface)
    typedef T first_argument_type;
    /// Second argument type (adaptable function object interface)
    typedef T second_argument_type;
    /// Result type (adaptable function object interface)
    typedef T result_type;

    /// Returns \a p * (log(\a p) - log(\a q)), or 0 if \a p == 0
    T operator()( const T &p, const T &q ) const {
        if( p == (T)0 )
            return (T)0;
        // std overloads plus argument-dependent lookup for custom types
        using std::log;
        return p * (log(p) - log(q));
    }
};
00136
00137
/// Function object useful for calculating the Hellinger distance.
/** Computes the pointwise term (sqrt(p) - sqrt(q))^2. Provides the
 *  adaptable-binary-function typedefs directly instead of deriving from
 *  std::binary_function (removed in C++17).
 */
template<typename T> struct fo_Hellinger {
    /// First argument type (adaptable function object interface)
    typedef T first_argument_type;
    /// Second argument type (adaptable function object interface)
    typedef T second_argument_type;
    /// Result type (adaptable function object interface)
    typedef T result_type;

    /// Returns (sqrt(\a p) - sqrt(\a q))^2
    T operator()( const T &p, const T &q ) const {
        // std overloads plus argument-dependent lookup for custom types
        using std::sqrt;
        T x = sqrt(p) - sqrt(q);
        return x * x;
    }
};
00146
00147
/// Function object that returns x to the power y.
/** Short-circuits the common exponent 1. Provides the
 *  adaptable-binary-function typedefs directly instead of deriving from
 *  std::binary_function (removed in C++17); the typedefs are kept because
 *  callers bind this functor with std::bind2nd.
 */
template<typename T> struct fo_pow {
    /// First argument type (adaptable function object interface)
    typedef T first_argument_type;
    /// Second argument type (adaptable function object interface)
    typedef T second_argument_type;
    /// Result type (adaptable function object interface)
    typedef T result_type;

    /// Returns \a x ^ \a y (and \a x unchanged when \a y == 1)
    T operator()( const T &x, const T &y ) const {
        if( y == (T)1 )
            return x;
        // std overloads plus argument-dependent lookup for custom types
        using std::pow;
        return pow( x, y );
    }
};
00158
00159
/// Function object that returns the maximum of two values.
/** Provides the adaptable-binary-function typedefs directly instead of
 *  deriving from std::binary_function (removed in C++17).
 */
template<typename T> struct fo_max {
    /// First argument type (adaptable function object interface)
    typedef T first_argument_type;
    /// Second argument type (adaptable function object interface)
    typedef T second_argument_type;
    /// Result type (adaptable function object interface)
    typedef T result_type;

    /// Returns \a x if \a x > \a y, otherwise \a y
    T operator()( const T &x, const T &y ) const {
        return (x > y) ? x : y;
    }
};
00167
00168
/// Function object that returns the minimum of two values.
/** Provides the adaptable-binary-function typedefs directly instead of
 *  deriving from std::binary_function (removed in C++17).
 */
template<typename T> struct fo_min {
    /// First argument type (adaptable function object interface)
    typedef T first_argument_type;
    /// Second argument type (adaptable function object interface)
    typedef T second_argument_type;
    /// Result type (adaptable function object interface)
    typedef T result_type;

    /// Returns \a y if \a x > \a y, otherwise \a x
    T operator()( const T &x, const T &y ) const {
        return (x > y) ? y : x;
    }
};
00176
00177
00179 template<typename T> struct fo_absdiff : public std::binary_function<T, T, T> {
00181 T operator()( const T &x, const T &y ) const {
00182 return dai::abs( x - y );
00183 }
00184 };
00185
00186
00188
00196 template <typename T>
00197 class TProb {
00198 public:
00200 typedef std::vector<T> container_type;
00201
00203 typedef TProb<T> this_type;
00204
00205 private:
00207 container_type _p;
00208
00209 public:
00211
00216 typedef enum { NORMPROB, NORMLINF } NormType;
00218
00226 typedef enum { DISTL1, DISTLINF, DISTTV, DISTKL, DISTHEL } DistType;
00227
00229
00230
00231 TProb() : _p() {}
00232
00234 explicit TProb( size_t n ) : _p( n, (T)1 / n ) {}
00235
00237 explicit TProb( size_t n, T p ) : _p( n, p ) {}
00238
00240
00247 template <typename TIterator>
00248 TProb( TIterator begin, TIterator end, size_t sizeHint=0 ) : _p() {
00249 _p.reserve( sizeHint );
00250 _p.insert( _p.begin(), begin, end );
00251 }
00252
00254
00257 template <typename S>
00258 TProb( const std::vector<S> &v ) : _p() {
00259 _p.reserve( v.size() );
00260 _p.insert( _p.begin(), v.begin(), v.end() );
00261 }
00263
00265 typedef typename container_type::const_iterator const_iterator;
00267 typedef typename container_type::iterator iterator;
00269 typedef typename container_type::const_reverse_iterator const_reverse_iterator;
00271 typedef typename container_type::reverse_iterator reverse_iterator;
00272
00274
00275
00276 iterator begin() { return _p.begin(); }
00278 const_iterator begin() const { return _p.begin(); }
00279
00281 iterator end() { return _p.end(); }
00283 const_iterator end() const { return _p.end(); }
00284
00286 reverse_iterator rbegin() { return _p.rbegin(); }
00288 const_reverse_iterator rbegin() const { return _p.rbegin(); }
00289
00291 reverse_iterator rend() { return _p.rend(); }
00293 const_reverse_iterator rend() const { return _p.rend(); }
00295
00297
00298 void resize( size_t sz ) {
00299 _p.resize( sz );
00300 }
00302
00304
00305
00306 T get( size_t i ) const {
00307 #ifdef DAI_DEBUG
00308 return _p.at(i);
00309 #else
00310 return _p[i];
00311 #endif
00312 }
00313
00315 void set( size_t i, T val ) {
00316 DAI_DEBASSERT( i < _p.size() );
00317 _p[i] = val;
00318 }
00320
00322
00323
00324 const container_type& p() const { return _p; }
00325
00327 container_type& p() { return _p; }
00328
00330 T operator[]( size_t i ) const { return get(i); }
00331
00333
00335 T& operator[]( size_t i ) { return _p[i]; }
00336
00338 size_t size() const { return _p.size(); }
00339
00341
00350 template<typename binOp, typename unOp> T accumulate( T init, binOp op1, unOp op2 ) const {
00351 T t = op2(init);
00352 for( const_iterator it = begin(); it != end(); it++ )
00353 t = op1( t, op2(*it) );
00354 return t;
00355 }
00356
00357
00359
00367 template<typename unOp> T accumulateSum( T init, unOp op ) const {
00368 T t = op(init);
00369 for( const_iterator it = begin(); it != end(); it++ )
00370 t += op(*it);
00371 return t;
00372 }
00373
00375
00383 template<typename unOp> T accumulateMax( T init, unOp op, bool minimize ) const {
00384 T t = op(init);
00385 if( minimize ) {
00386 for( const_iterator it = begin(); it != end(); it++ )
00387 t = std::min( t, op(*it) );
00388 } else {
00389 for( const_iterator it = begin(); it != end(); it++ )
00390 t = std::max( t, op(*it) );
00391 }
00392 return t;
00393 }
00394
00396 T entropy() const { return -accumulateSum( (T)0, fo_plog0p<T>() ); }
00397
00399 T max() const { return accumulateMax( (T)(-INFINITY), fo_id<T>(), false ); }
00400
00402 T min() const { return accumulateMax( (T)INFINITY, fo_id<T>(), true ); }
00403
00405 T sum() const { return accumulateSum( (T)0, fo_id<T>() ); }
00406
00408 T sumAbs() const { return accumulateSum( (T)0, fo_abs<T>() ); }
00409
00411 T maxAbs() const { return accumulateMax( (T)0, fo_abs<T>(), false ); }
00412
00414 bool hasNaNs() const {
00415 bool foundnan = false;
00416 for( const_iterator x = _p.begin(); x != _p.end(); x++ )
00417 if( isnan( *x ) ) {
00418 foundnan = true;
00419 break;
00420 }
00421 return foundnan;
00422 }
00423
00425 bool hasNegatives() const {
00426 return (std::find_if( _p.begin(), _p.end(), std::bind2nd( std::less<T>(), (T)0 ) ) != _p.end());
00427 }
00428
00430 std::pair<size_t,T> argmax() const {
00431 T max = _p[0];
00432 size_t arg = 0;
00433 for( size_t i = 1; i < size(); i++ ) {
00434 if( _p[i] > max ) {
00435 max = _p[i];
00436 arg = i;
00437 }
00438 }
00439 return std::make_pair( arg, max );
00440 }
00441
00443 size_t draw() {
00444 Real x = rnd_uniform() * sum();
00445 T s = 0;
00446 for( size_t i = 0; i < size(); i++ ) {
00447 s += get(i);
00448 if( s > x )
00449 return i;
00450 }
00451 return( size() - 1 );
00452 }
00453
00455
00457 bool operator<( const this_type& q ) const {
00458 DAI_DEBASSERT( size() == q.size() );
00459 return lexicographical_compare( begin(), end(), q.begin(), q.end() );
00460 }
00461
00463 bool operator==( const this_type& q ) const {
00464 if( size() != q.size() )
00465 return false;
00466 return p() == q.p();
00467 }
00469
00471
00472
00473 template<typename unaryOp> this_type pwUnaryTr( unaryOp op ) const {
00474 this_type r;
00475 r._p.reserve( size() );
00476 std::transform( _p.begin(), _p.end(), back_inserter( r._p ), op );
00477 return r;
00478 }
00479
00481 this_type operator- () const { return pwUnaryTr( std::negate<T>() ); }
00482
00484 this_type abs() const { return pwUnaryTr( fo_abs<T>() ); }
00485
00487 this_type exp() const { return pwUnaryTr( fo_exp<T>() ); }
00488
00490
00492 this_type log(bool zero=false) const {
00493 if( zero )
00494 return pwUnaryTr( fo_log0<T>() );
00495 else
00496 return pwUnaryTr( fo_log<T>() );
00497 }
00498
00500
00502 this_type inverse(bool zero=true) const {
00503 if( zero )
00504 return pwUnaryTr( fo_inv0<T>() );
00505 else
00506 return pwUnaryTr( fo_inv<T>() );
00507 }
00508
00510
00512 this_type normalized( ProbNormType norm = dai::NORMPROB ) const {
00513 T Z = 0;
00514 if( norm == dai::NORMPROB )
00515 Z = sum();
00516 else if( norm == dai::NORMLINF )
00517 Z = maxAbs();
00518 if( Z == (T)0 ) {
00519 DAI_THROW(NOT_NORMALIZABLE);
00520 return *this;
00521 } else
00522 return pwUnaryTr( std::bind2nd( std::divides<T>(), Z ) );
00523 }
00525
00527
00528
00529 template<typename unaryOp> this_type& pwUnaryOp( unaryOp op ) {
00530 std::transform( _p.begin(), _p.end(), _p.begin(), op );
00531 return *this;
00532 }
00533
00535 this_type& randomize() {
00536 std::generate( _p.begin(), _p.end(), rnd_uniform );
00537 return *this;
00538 }
00539
00541 this_type& setUniform () {
00542 fill( (T)1 / size() );
00543 return *this;
00544 }
00545
00547 this_type& takeAbs() { return pwUnaryOp( fo_abs<T>() ); }
00548
00550 this_type& takeExp() { return pwUnaryOp( fo_exp<T>() ); }
00551
00553
00555 this_type& takeLog(bool zero=false) {
00556 if( zero ) {
00557 return pwUnaryOp( fo_log0<T>() );
00558 } else
00559 return pwUnaryOp( fo_log<T>() );
00560 }
00561
00563
00565 T normalize( ProbNormType norm=dai::NORMPROB ) {
00566 T Z = 0;
00567 if( norm == dai::NORMPROB )
00568 Z = sum();
00569 else if( norm == dai::NORMLINF )
00570 Z = maxAbs();
00571 if( Z == (T)0 )
00572 DAI_THROW(NOT_NORMALIZABLE);
00573 else
00574 *this /= Z;
00575 return Z;
00576 }
00578
00580
00581
00582 this_type& fill( T x ) {
00583 std::fill( _p.begin(), _p.end(), x );
00584 return *this;
00585 }
00586
00588 this_type& operator+= (T x) {
00589 if( x != 0 )
00590 return pwUnaryOp( std::bind2nd( std::plus<T>(), x ) );
00591 else
00592 return *this;
00593 }
00594
00596 this_type& operator-= (T x) {
00597 if( x != 0 )
00598 return pwUnaryOp( std::bind2nd( std::minus<T>(), x ) );
00599 else
00600 return *this;
00601 }
00602
00604 this_type& operator*= (T x) {
00605 if( x != 1 )
00606 return pwUnaryOp( std::bind2nd( std::multiplies<T>(), x ) );
00607 else
00608 return *this;
00609 }
00610
00612 this_type& operator/= (T x) {
00613 if( x != 1 )
00614 return pwUnaryOp( std::bind2nd( fo_divides0<T>(), x ) );
00615 else
00616 return *this;
00617 }
00618
00620 this_type& operator^= (T x) {
00621 if( x != (T)1 )
00622 return pwUnaryOp( std::bind2nd( fo_pow<T>(), x) );
00623 else
00624 return *this;
00625 }
00627
00629
00630
00631 this_type operator+ (T x) const { return pwUnaryTr( std::bind2nd( std::plus<T>(), x ) ); }
00632
00634 this_type operator- (T x) const { return pwUnaryTr( std::bind2nd( std::minus<T>(), x ) ); }
00635
00637 this_type operator* (T x) const { return pwUnaryTr( std::bind2nd( std::multiplies<T>(), x ) ); }
00638
00640 this_type operator/ (T x) const { return pwUnaryTr( std::bind2nd( fo_divides0<T>(), x ) ); }
00641
00643 this_type operator^ (T x) const { return pwUnaryTr( std::bind2nd( fo_pow<T>(), x ) ); }
00645
00647
00648
00649
00653 template<typename binaryOp> this_type& pwBinaryOp( const this_type &q, binaryOp op ) {
00654 DAI_DEBASSERT( size() == q.size() );
00655 std::transform( _p.begin(), _p.end(), q._p.begin(), _p.begin(), op );
00656 return *this;
00657 }
00658
00660
00662 this_type& operator+= (const this_type & q) { return pwBinaryOp( q, std::plus<T>() ); }
00663
00665
00667 this_type& operator-= (const this_type & q) { return pwBinaryOp( q, std::minus<T>() ); }
00668
00670
00672 this_type& operator*= (const this_type & q) { return pwBinaryOp( q, std::multiplies<T>() ); }
00673
00675
00678 this_type& operator/= (const this_type & q) { return pwBinaryOp( q, fo_divides0<T>() ); }
00679
00681
00684 this_type& divide (const this_type & q) { return pwBinaryOp( q, std::divides<T>() ); }
00685
00687
00689 this_type& operator^= (const this_type & q) { return pwBinaryOp( q, fo_pow<T>() ); }
00691
00693
00694
00695
00699 template<typename binaryOp> this_type pwBinaryTr( const this_type &q, binaryOp op ) const {
00700 DAI_DEBASSERT( size() == q.size() );
00701 TProb<T> r;
00702 r._p.reserve( size() );
00703 std::transform( _p.begin(), _p.end(), q._p.begin(), back_inserter( r._p ), op );
00704 return r;
00705 }
00706
00708
00710 this_type operator+ ( const this_type& q ) const { return pwBinaryTr( q, std::plus<T>() ); }
00711
00713
00715 this_type operator- ( const this_type& q ) const { return pwBinaryTr( q, std::minus<T>() ); }
00716
00718
00720 this_type operator* ( const this_type &q ) const { return pwBinaryTr( q, std::multiplies<T>() ); }
00721
00723
00726 this_type operator/ ( const this_type &q ) const { return pwBinaryTr( q, fo_divides0<T>() ); }
00727
00729
00732 this_type divided_by( const this_type &q ) const { return pwBinaryTr( q, std::divides<T>() ); }
00733
00735
00737 this_type operator^ ( const this_type &q ) const { return pwBinaryTr( q, fo_pow<T>() ); }
00739
00741
00743 template<typename binOp1, typename binOp2> T innerProduct( const this_type &q, T init, binOp1 binaryOp1, binOp2 binaryOp2 ) const {
00744 DAI_DEBASSERT( size() == q.size() );
00745 return std::inner_product( begin(), end(), q.begin(), init, binaryOp1, binaryOp2 );
00746 }
00747 };
00748
00749
00751
00754 template<typename T> T dist( const TProb<T> &p, const TProb<T> &q, ProbDistType dt ) {
00755 switch( dt ) {
00756 case DISTL1:
00757 return p.innerProduct( q, (T)0, std::plus<T>(), fo_absdiff<T>() );
00758 case DISTLINF:
00759 return p.innerProduct( q, (T)0, fo_max<T>(), fo_absdiff<T>() );
00760 case DISTTV:
00761 return p.innerProduct( q, (T)0, std::plus<T>(), fo_absdiff<T>() ) / 2;
00762 case DISTKL:
00763 return p.innerProduct( q, (T)0, std::plus<T>(), fo_KL<T>() );
00764 case DISTHEL:
00765 return p.innerProduct( q, (T)0, std::plus<T>(), fo_Hellinger<T>() ) / 2;
00766 default:
00767 DAI_THROW(UNKNOWN_ENUM_VALUE);
00768 return INFINITY;
00769 }
00770 }
00771
00772
00774
00776 template<typename T> std::ostream& operator<< (std::ostream& os, const TProb<T>& p) {
00777 os << "(";
00778 for( size_t i = 0; i < p.size(); i++ )
00779 os << ((i != 0) ? ", " : "") << p.get(i);
00780 os << ")";
00781 return os;
00782 }
00783
00784
00786
00789 template<typename T> TProb<T> min( const TProb<T> &a, const TProb<T> &b ) {
00790 return a.pwBinaryTr( b, fo_min<T>() );
00791 }
00792
00793
00795
00798 template<typename T> TProb<T> max( const TProb<T> &a, const TProb<T> &b ) {
00799 return a.pwBinaryTr( b, fo_max<T>() );
00800 }
00801
00802
/// Vector of Real values, the standard instantiation of TProb (Real is declared outside this file, presumably in dai/util.h)
typedef TProb<Real> Prob;
00805
00806
00807 }
00808
00809
00810 #endif