libDAI
|
00001 /* This file is part of libDAI - http://www.libdai.org/ 00002 * 00003 * Copyright (c) 2006-2011, The libDAI authors. All rights reserved. 00004 * 00005 * Use of this source code is governed by a BSD-style license that can be found in the LICENSE file. 00006 */ 00007 00008 00011 00012 00013 #ifndef __defined_libdai_prob_h 00014 #define __defined_libdai_prob_h 00015 00016 00017 #include <cmath> 00018 #include <vector> 00019 #include <ostream> 00020 #include <algorithm> 00021 #include <numeric> 00022 #include <functional> 00023 #include <dai/util.h> 00024 #include <dai/exceptions.h> 00025 00026 00027 namespace dai { 00028 00029 00031 template<typename T> struct fo_id : public std::unary_function<T, T> { 00033 T operator()( const T &x ) const { 00034 return x; 00035 } 00036 }; 00037 00038 00040 template<typename T> struct fo_abs : public std::unary_function<T, T> { 00042 T operator()( const T &x ) const { 00043 if( x < (T)0 ) 00044 return -x; 00045 else 00046 return x; 00047 } 00048 }; 00049 00050 00052 template<typename T> struct fo_exp : public std::unary_function<T, T> { 00054 T operator()( const T &x ) const { 00055 return exp( x ); 00056 } 00057 }; 00058 00059 00061 template<typename T> struct fo_log : public std::unary_function<T, T> { 00063 T operator()( const T &x ) const { 00064 return log( x ); 00065 } 00066 }; 00067 00068 00070 template<typename T> struct fo_log0 : public std::unary_function<T, T> { 00072 T operator()( const T &x ) const { 00073 if( x ) 00074 return log( x ); 00075 else 00076 return 0; 00077 } 00078 }; 00079 00080 00082 template<typename T> struct fo_inv : public std::unary_function<T, T> { 00084 T operator()( const T &x ) const { 00085 return 1 / x; 00086 } 00087 }; 00088 00089 00091 template<typename T> struct fo_inv0 : public std::unary_function<T, T> { 00093 T operator()( const T &x ) const { 00094 if( x ) 00095 return 1 / x; 00096 else 00097 return 0; 00098 } 00099 }; 00100 00101 00103 template<typename T> struct fo_plog0p : public std::unary_function<T, T> { 00105 T operator()( const T &p ) const { 00106 return p * dai::log0(p); 00107 } 00108 }; 00109 00110 00112 template<typename T> struct fo_divides0 : public std::binary_function<T, T, T> { 00114 T operator()( const T &x, const T &y ) const { 00115 if( y == (T)0 ) 00116 return (T)0; 00117 else 00118 return x / y; 00119 } 00120 }; 00121 00122 00124 template<typename T> struct fo_KL : public std::binary_function<T, T, T> { 00126 T operator()( const T &p, const T &q ) const { 00127 if( p == (T)0 ) 00128 return (T)0; 00129 else 00130 return p * (log(p) - log(q)); 00131 } 00132 }; 00133 00134 00136 template<typename T> struct fo_Hellinger : public std::binary_function<T, T, T> { 00138 T operator()( const T &p, const T &q ) const { 00139 T x = sqrt(p) - sqrt(q); 00140 return x * x; 00141 } 00142 }; 00143 00144 00146 template<typename T> struct fo_pow : public std::binary_function<T, T, T> { 00148 T operator()( const T &x, const T &y ) const { 00149 if( y != 1 ) 00150 return pow( x, y ); 00151 else 00152 return x; 00153 } 00154 }; 00155 00156 00158 template<typename T> struct fo_max : public std::binary_function<T, T, T> { 00160 T operator()( const T &x, const T &y ) const { 00161 return (x > y) ? x : y; 00162 } 00163 }; 00164 00165 00167 template<typename T> struct fo_min : public std::binary_function<T, T, T> { 00169 T operator()( const T &x, const T &y ) const { 00170 return (x > y) ? y : x; 00171 } 00172 }; 00173 00174 00176 template<typename T> struct fo_absdiff : public std::binary_function<T, T, T> { 00178 T operator()( const T &x, const T &y ) const { 00179 return dai::abs( x - y ); 00180 } 00181 }; 00182 00183 00185 00193 template <typename T> 00194 class TProb { 00195 public: 00197 typedef std::vector<T> container_type; 00198 00200 typedef TProb<T> this_type; 00201 00202 private: 00204 container_type _p; 00205 00206 public: 00208 00209 00210 TProb() : _p() {} 00211 00213 explicit TProb( size_t n ) : _p( n, (T)1 / n ) {} 00214 00216 explicit TProb( size_t n, T p ) : _p( n, p ) {} 00217 00219 00225 template <typename TIterator> 00226 TProb( TIterator begin, TIterator end, size_t sizeHint ) : _p() { 00227 _p.reserve( sizeHint ); 00228 _p.insert( _p.begin(), begin, end ); 00229 } 00230 00232 00235 template <typename S> 00236 TProb( const std::vector<S> &v ) : _p() { 00237 _p.reserve( v.size() ); 00238 _p.insert( _p.begin(), v.begin(), v.end() ); 00239 } 00241 00243 typedef typename container_type::const_iterator const_iterator; 00245 typedef typename container_type::iterator iterator; 00247 typedef typename container_type::const_reverse_iterator const_reverse_iterator; 00249 typedef typename container_type::reverse_iterator reverse_iterator; 00250 00252 00253 00254 iterator begin() { return _p.begin(); } 00256 const_iterator begin() const { return _p.begin(); } 00257 00259 iterator end() { return _p.end(); } 00261 const_iterator end() const { return _p.end(); } 00262 00264 reverse_iterator rbegin() { return _p.rbegin(); } 00266 const_reverse_iterator rbegin() const { return _p.rbegin(); } 00267 00269 reverse_iterator rend() { return _p.rend(); } 00271 const_reverse_iterator rend() const { return _p.rend(); } 00273 00275 00276 void resize( size_t sz ) { 00277 _p.resize( sz ); 00278 } 00280 00282 00283 00284 T get( size_t i ) const { 00285 #ifdef DAI_DEBUG 00286 return _p.at(i); 00287 #else 00288 return _p[i]; 00289 #endif 00290 } 00291 00293 void set( size_t i, T val ) { 00294 DAI_DEBASSERT( i < _p.size() ); 00295 _p[i] = val; 00296 } 00298 00300 00301 00302 const container_type& p() const { return _p; } 00303 00305 container_type& p() { return _p; } 00306 00308 T operator[]( size_t i ) const { return get(i); } 00309 00311 size_t size() const { return _p.size(); } 00312 00314 00322 template<typename unOp> T accumulateSum( T init, unOp op ) const { 00323 T t = op(init); 00324 for( const_iterator it = begin(); it != end(); it++ ) 00325 t += op(*it); 00326 return t; 00327 } 00328 00330 00338 template<typename unOp> T accumulateMax( T init, unOp op, bool minimize ) const { 00339 T t = op(init); 00340 if( minimize ) { 00341 for( const_iterator it = begin(); it != end(); it++ ) 00342 t = std::min( t, op(*it) ); 00343 } else { 00344 for( const_iterator it = begin(); it != end(); it++ ) 00345 t = std::max( t, op(*it) ); 00346 } 00347 return t; 00348 } 00349 00351 T entropy() const { return -accumulateSum( (T)0, fo_plog0p<T>() ); } 00352 00354 T max() const { return accumulateMax( (T)(-INFINITY), fo_id<T>(), false ); } 00355 00357 T min() const { return accumulateMax( (T)INFINITY, fo_id<T>(), true ); } 00358 00360 T sum() const { return accumulateSum( (T)0, fo_id<T>() ); } 00361 00363 T sumAbs() const { return accumulateSum( (T)0, fo_abs<T>() ); } 00364 00366 T maxAbs() const { return accumulateMax( (T)0, fo_abs<T>(), false ); } 00367 00369 bool hasNaNs() const { 00370 bool foundnan = false; 00371 for( const_iterator x = _p.begin(); x != _p.end(); x++ ) 00372 if( dai::isnan( *x ) ) { 00373 foundnan = true; 00374 break; 00375 } 00376 return foundnan; 00377 } 00378 00380 bool hasNegatives() const { 00381 return (std::find_if( _p.begin(), _p.end(), std::bind2nd( std::less<T>(), (T)0 ) ) != _p.end()); 00382 } 00383 00385 std::pair<size_t,T> argmax() const { 00386 T max = _p[0]; 00387 size_t arg = 0; 00388 for( size_t i = 1; i < size(); i++ ) { 00389 if( _p[i] > max ) { 00390 max = _p[i]; 00391 arg = i; 00392 } 00393 } 00394 return std::make_pair( arg, max ); 00395 } 00396 00398 size_t draw() { 00399 Real x = rnd_uniform() * sum(); 00400 T s = 0; 00401 for( size_t i = 0; i < size(); i++ ) { 00402 s += get(i); 00403 if( s > x ) 00404 return i; 00405 } 00406 return( size() - 1 ); 00407 } 00408 00410 00412 bool operator<( const this_type& q ) const { 00413 DAI_DEBASSERT( size() == q.size() ); 00414 return lexicographical_compare( begin(), end(), q.begin(), q.end() ); 00415 } 00416 00418 bool operator==( const this_type& q ) const { 00419 if( size() != q.size() ) 00420 return false; 00421 return p() == q.p(); 00422 } 00424 00426 00427 00428 template<typename unaryOp> this_type pwUnaryTr( unaryOp op ) const { 00429 this_type r; 00430 r._p.reserve( size() ); 00431 std::transform( _p.begin(), _p.end(), back_inserter( r._p ), op ); 00432 return r; 00433 } 00434 00436 this_type operator- () const { return pwUnaryTr( std::negate<T>() ); } 00437 00439 this_type abs() const { return pwUnaryTr( fo_abs<T>() ); } 00440 00442 this_type exp() const { return pwUnaryTr( fo_exp<T>() ); } 00443 00445 00447 this_type log(bool zero=false) const { 00448 if( zero ) 00449 return pwUnaryTr( fo_log0<T>() ); 00450 else 00451 return pwUnaryTr( fo_log<T>() ); 00452 } 00453 00455 00457 this_type inverse(bool zero=true) const { 00458 if( zero ) 00459 return pwUnaryTr( fo_inv0<T>() ); 00460 else 00461 return pwUnaryTr( fo_inv<T>() ); 00462 } 00463 00465 00467 this_type normalized( ProbNormType norm = dai::NORMPROB ) const { 00468 T Z = 0; 00469 if( norm == dai::NORMPROB ) 00470 Z = sum(); 00471 else if( norm == dai::NORMLINF ) 00472 Z = maxAbs(); 00473 if( Z == (T)0 ) { 00474 DAI_THROW(NOT_NORMALIZABLE); 00475 return *this; 00476 } else 00477 return pwUnaryTr( std::bind2nd( std::divides<T>(), Z ) ); 00478 } 00480 00482 00483 00484 template<typename unaryOp> this_type& pwUnaryOp( unaryOp op ) { 00485 std::transform( _p.begin(), _p.end(), _p.begin(), op ); 00486 return *this; 00487 } 00488 00490 this_type& randomize() { 00491 std::generate( _p.begin(), _p.end(), rnd_uniform ); 00492 return *this; 00493 } 00494 00496 this_type& setUniform () { 00497 fill( (T)1 / size() ); 00498 return *this; 00499 } 00500 00502 this_type& takeAbs() { return pwUnaryOp( fo_abs<T>() ); } 00503 00505 this_type& takeExp() { return pwUnaryOp( fo_exp<T>() ); } 00506 00508 00510 this_type& takeLog(bool zero=false) { 00511 if( zero ) { 00512 return pwUnaryOp( fo_log0<T>() ); 00513 } else 00514 return pwUnaryOp( fo_log<T>() ); 00515 } 00516 00518 00520 T normalize( ProbNormType norm=dai::NORMPROB ) { 00521 T Z = 0; 00522 if( norm == dai::NORMPROB ) 00523 Z = sum(); 00524 else if( norm == dai::NORMLINF ) 00525 Z = maxAbs(); 00526 if( Z == (T)0 ) 00527 DAI_THROW(NOT_NORMALIZABLE); 00528 else 00529 *this /= Z; 00530 return Z; 00531 } 00533 00535 00536 00537 this_type& fill( T x ) { 00538 std::fill( _p.begin(), _p.end(), x ); 00539 return *this; 00540 } 00541 00543 this_type& operator+= (T x) { 00544 if( x != 0 ) 00545 return pwUnaryOp( std::bind2nd( std::plus<T>(), x ) ); 00546 else 00547 return *this; 00548 } 00549 00551 this_type& operator-= (T x) { 00552 if( x != 0 ) 00553 return pwUnaryOp( std::bind2nd( std::minus<T>(), x ) ); 00554 else 00555 return *this; 00556 } 00557 00559 this_type& operator*= (T x) { 00560 if( x != 1 ) 00561 return pwUnaryOp( std::bind2nd( std::multiplies<T>(), x ) ); 00562 else 00563 return *this; 00564 } 00565 00567 this_type& operator/= (T x) { 00568 if( x != 1 ) 00569 return pwUnaryOp( std::bind2nd( fo_divides0<T>(), x ) ); 00570 else 00571 return *this; 00572 } 00573 00575 this_type& operator^= (T x) { 00576 if( x != (T)1 ) 00577 return pwUnaryOp( std::bind2nd( fo_pow<T>(), x) ); 00578 else 00579 return *this; 00580 } 00582 00584 00585 00586 this_type operator+ (T x) const { return pwUnaryTr( std::bind2nd( std::plus<T>(), x ) ); } 00587 00589 this_type operator- (T x) const { return pwUnaryTr( std::bind2nd( std::minus<T>(), x ) ); } 00590 00592 this_type operator* (T x) const { return pwUnaryTr( std::bind2nd( std::multiplies<T>(), x ) ); } 00593 00595 this_type operator/ (T x) const { return pwUnaryTr( std::bind2nd( fo_divides0<T>(), x ) ); } 00596 00598 this_type operator^ (T x) const { return pwUnaryTr( std::bind2nd( fo_pow<T>(), x ) ); } 00600 00602 00603 00604 00608 template<typename binaryOp> this_type& pwBinaryOp( const this_type &q, binaryOp op ) { 00609 DAI_DEBASSERT( size() == q.size() ); 00610 std::transform( _p.begin(), _p.end(), q._p.begin(), _p.begin(), op ); 00611 return *this; 00612 } 00613 00615 00617 this_type& operator+= (const this_type & q) { return pwBinaryOp( q, std::plus<T>() ); } 00618 00620 00622 this_type& operator-= (const this_type & q) { return pwBinaryOp( q, std::minus<T>() ); } 00623 00625 00627 this_type& operator*= (const this_type & q) { return pwBinaryOp( q, std::multiplies<T>() ); } 00628 00630 00633 this_type& operator/= (const this_type & q) { return pwBinaryOp( q, fo_divides0<T>() ); } 00634 00636 00639 this_type& divide (const this_type & q) { return pwBinaryOp( q, std::divides<T>() ); } 00640 00642 00644 this_type& operator^= (const this_type & q) { return pwBinaryOp( q, fo_pow<T>() ); } 00646 00648 00649 00650 00654 template<typename binaryOp> this_type pwBinaryTr( const this_type &q, binaryOp op ) const { 00655 DAI_DEBASSERT( size() == q.size() ); 00656 TProb<T> r; 00657 r._p.reserve( size() ); 00658 std::transform( _p.begin(), _p.end(), q._p.begin(), back_inserter( r._p ), op ); 00659 return r; 00660 } 00661 00663 00665 this_type operator+ ( const this_type& q ) const { return pwBinaryTr( q, std::plus<T>() ); } 00666 00668 00670 this_type operator- ( const this_type& q ) const { return pwBinaryTr( q, std::minus<T>() ); } 00671 00673 00675 this_type operator* ( const this_type &q ) const { return pwBinaryTr( q, std::multiplies<T>() ); } 00676 00678 00681 this_type operator/ ( const this_type &q ) const { return pwBinaryTr( q, fo_divides0<T>() ); } 00682 00684 00687 this_type divided_by( const this_type &q ) const { return pwBinaryTr( q, std::divides<T>() ); } 00688 00690 00692 this_type operator^ ( const this_type &q ) const { return pwBinaryTr( q, fo_pow<T>() ); } 00694 00696 00698 template<typename binOp1, typename binOp2> T innerProduct( const this_type &q, T init, binOp1 binaryOp1, binOp2 binaryOp2 ) const { 00699 DAI_DEBASSERT( size() == q.size() ); 00700 return std::inner_product( begin(), end(), q.begin(), init, binaryOp1, binaryOp2 ); 00701 } 00702 }; 00703 00704 00706 00709 template<typename T> T dist( const TProb<T> &p, const TProb<T> &q, ProbDistType dt ) { 00710 switch( dt ) { 00711 case DISTL1: 00712 return p.innerProduct( q, (T)0, std::plus<T>(), fo_absdiff<T>() ); 00713 case DISTLINF: 00714 return p.innerProduct( q, (T)0, fo_max<T>(), fo_absdiff<T>() ); 00715 case DISTTV: 00716 return p.innerProduct( q, (T)0, std::plus<T>(), fo_absdiff<T>() ) / 2; 00717 case DISTKL: 00718 return p.innerProduct( q, (T)0, std::plus<T>(), fo_KL<T>() ); 00719 case DISTHEL: 00720 return p.innerProduct( q, (T)0, std::plus<T>(), fo_Hellinger<T>() ) / 2; 00721 default: 00722 DAI_THROW(UNKNOWN_ENUM_VALUE); 00723 return INFINITY; 00724 } 00725 } 00726 00727 00729 00731 template<typename T> std::ostream& operator<< (std::ostream& os, const TProb<T>& p) { 00732 os << "("; 00733 for( size_t i = 0; i < p.size(); i++ ) 00734 os << ((i != 0) ? ", " : "") << p.get(i); 00735 os << ")"; 00736 return os; 00737 } 00738 00739 00741 00744 template<typename T> TProb<T> min( const TProb<T> &a, const TProb<T> &b ) { 00745 return a.pwBinaryTr( b, fo_min<T>() ); 00746 } 00747 00748 00750 00753 template<typename T> TProb<T> max( const TProb<T> &a, const TProb<T> &b ) { 00754 return a.pwBinaryTr( b, fo_max<T>() ); 00755 } 00756 00757 00759 typedef TProb<Real> Prob; 00760 00761 00762 } // end of namespace dai 00763 00764 00765 #endif