// libDAI
00001 /* This file is part of libDAI - http://www.libdai.org/ 00002 * 00003 * Copyright (c) 2006-2011, The libDAI authors. All rights reserved. 00004 * 00005 * Use of this source code is governed by a BSD-style license that can be found in the LICENSE file. 00006 */ 00007 00008 00011 00012 00013 #ifndef __defined_libdai_factor_h 00014 #define __defined_libdai_factor_h 00015 00016 00017 #include <iostream> 00018 #include <functional> 00019 #include <cmath> 00020 #include <dai/prob.h> 00021 #include <dai/varset.h> 00022 #include <dai/index.h> 00023 #include <dai/util.h> 00024 00025 00026 namespace dai { 00027 00028 00030 00054 template <typename T> 00055 class TFactor { 00056 private: 00058 VarSet _vs; 00060 TProb<T> _p; 00061 00062 public: 00064 00065 00066 TFactor ( T p = 1 ) : _vs(), _p(1,p) {} 00067 00069 TFactor( const Var &v ) : _vs(v), _p(v.states()) {} 00070 00072 TFactor( const VarSet& vars ) : _vs(vars), _p() { 00073 _p = TProb<T>( BigInt_size_t( _vs.nrStates() ) ); 00074 } 00075 00077 TFactor( const VarSet& vars, T p ) : _vs(vars), _p() { 00078 _p = TProb<T>( BigInt_size_t( _vs.nrStates() ), p ); 00079 } 00080 00082 00086 template<typename S> 00087 TFactor( const VarSet& vars, const std::vector<S> &x ) : _vs(vars), _p() { 00088 DAI_ASSERT( x.size() == vars.nrStates() ); 00089 _p = TProb<T>( x.begin(), x.end(), x.size() ); 00090 } 00091 00093 00096 TFactor( const VarSet& vars, const T* p ) : _vs(vars), _p() { 00097 size_t N = BigInt_size_t( _vs.nrStates() ); 00098 _p = TProb<T>( p, p + N, N ); 00099 } 00100 00102 TFactor( const VarSet& vars, const TProb<T> &p ) : _vs(vars), _p(p) { 00103 DAI_ASSERT( _vs.nrStates() == _p.size() ); 00104 } 00105 00107 TFactor( const std::vector<Var> &vars, const std::vector<T> &p ) : _vs(vars.begin(), vars.end(), vars.size()), _p(p.size()) { 00108 BigInt nrStates = 1; 00109 for( size_t i = 0; i < vars.size(); i++ ) 00110 nrStates *= vars[i].states(); 00111 DAI_ASSERT( nrStates == p.size() ); 00112 Permute permindex(vars); 
00113 for( size_t li = 0; li < p.size(); ++li ) 00114 _p.set( permindex.convertLinearIndex(li), p[li] ); 00115 } 00117 00119 00120 00121 void set( size_t i, T val ) { _p.set( i, val ); } 00122 00124 T get( size_t i ) const { return _p[i]; } 00126 00128 00129 00130 const TProb<T>& p() const { return _p; } 00131 00133 TProb<T>& p() { return _p; } 00134 00136 T operator[] (size_t i) const { return _p[i]; } 00137 00139 const VarSet& vars() const { return _vs; } 00140 00142 VarSet& vars() { return _vs; } 00143 00145 00147 size_t nrStates() const { return _p.size(); } 00148 00150 T entropy() const { return _p.entropy(); } 00151 00153 T max() const { return _p.max(); } 00154 00156 T min() const { return _p.min(); } 00157 00159 T sum() const { return _p.sum(); } 00160 00162 T sumAbs() const { return _p.sumAbs(); } 00163 00165 T maxAbs() const { return _p.maxAbs(); } 00166 00168 bool hasNaNs() const { return _p.hasNaNs(); } 00169 00171 bool hasNegatives() const { return _p.hasNegatives(); } 00172 00174 T strength( const Var &i, const Var &j ) const; 00175 00177 bool operator==( const TFactor<T>& y ) const { 00178 return (_vs == y._vs) && (_p == y._p); 00179 } 00181 00183 00184 00185 TFactor<T> operator- () const { 00186 // Note: the alternative (shorter) way of implementing this, 00187 // return TFactor<T>( _vs, _p.abs() ); 00188 // is slower because it invokes the copy constructor of TProb<T> 00189 TFactor<T> x; 00190 x._vs = _vs; 00191 x._p = -_p; 00192 return x; 00193 } 00194 00196 TFactor<T> abs() const { 00197 TFactor<T> x; 00198 x._vs = _vs; 00199 x._p = _p.abs(); 00200 return x; 00201 } 00202 00204 TFactor<T> exp() const { 00205 TFactor<T> x; 00206 x._vs = _vs; 00207 x._p = _p.exp(); 00208 return x; 00209 } 00210 00212 00214 TFactor<T> log(bool zero=false) const { 00215 TFactor<T> x; 00216 x._vs = _vs; 00217 x._p = _p.log(zero); 00218 return x; 00219 } 00220 00222 00224 TFactor<T> inverse(bool zero=true) const { 00225 TFactor<T> x; 00226 x._vs = _vs; 00227 x._p = 
_p.inverse(zero); 00228 return x; 00229 } 00230 00232 00234 TFactor<T> normalized( ProbNormType norm=NORMPROB ) const { 00235 TFactor<T> x; 00236 x._vs = _vs; 00237 x._p = _p.normalized( norm ); 00238 return x; 00239 } 00241 00243 00244 00245 TFactor<T>& randomize() { _p.randomize(); return *this; } 00246 00248 TFactor<T>& setUniform() { _p.setUniform(); return *this; } 00249 00251 TFactor<T>& takeAbs() { _p.takeAbs(); return *this; } 00252 00254 TFactor<T>& takeExp() { _p.takeExp(); return *this; } 00255 00257 00259 TFactor<T>& takeLog( bool zero = false ) { _p.takeLog(zero); return *this; } 00260 00262 00264 T normalize( ProbNormType norm=NORMPROB ) { return _p.normalize( norm ); } 00266 00268 00269 00270 TFactor<T>& fill (T x) { _p.fill( x ); return *this; } 00271 00273 TFactor<T>& operator+= (T x) { _p += x; return *this; } 00274 00276 TFactor<T>& operator-= (T x) { _p -= x; return *this; } 00277 00279 TFactor<T>& operator*= (T x) { _p *= x; return *this; } 00280 00282 TFactor<T>& operator/= (T x) { _p /= x; return *this; } 00283 00285 TFactor<T>& operator^= (T x) { _p ^= x; return *this; } 00287 00289 00290 00291 TFactor<T> operator+ (T x) const { 00292 // Note: the alternative (shorter) way of implementing this, 00293 // TFactor<T> result(*this); 00294 // result._p += x; 00295 // is slower because it invokes the copy constructor of TFactor<T> 00296 TFactor<T> result; 00297 result._vs = _vs; 00298 result._p = p() + x; 00299 return result; 00300 } 00301 00303 TFactor<T> operator- (T x) const { 00304 TFactor<T> result; 00305 result._vs = _vs; 00306 result._p = p() - x; 00307 return result; 00308 } 00309 00311 TFactor<T> operator* (T x) const { 00312 TFactor<T> result; 00313 result._vs = _vs; 00314 result._p = p() * x; 00315 return result; 00316 } 00317 00319 TFactor<T> operator/ (T x) const { 00320 TFactor<T> result; 00321 result._vs = _vs; 00322 result._p = p() / x; 00323 return result; 00324 } 00325 00327 TFactor<T> operator^ (T x) const { 00328 TFactor<T> 
result; 00329 result._vs = _vs; 00330 result._p = p() ^ x; 00331 return result; 00332 } 00334 00336 00337 00338 00342 template<typename binOp> TFactor<T>& binaryOp( const TFactor<T> &g, binOp op ) { 00343 if( _vs == g._vs ) // optimize special case 00344 _p.pwBinaryOp( g._p, op ); 00345 else { 00346 TFactor<T> f(*this); // make a copy 00347 _vs |= g._vs; 00348 size_t N = BigInt_size_t( _vs.nrStates() ); 00349 00350 IndexFor i_f( f._vs, _vs ); 00351 IndexFor i_g( g._vs, _vs ); 00352 00353 _p.p().clear(); 00354 _p.p().reserve( N ); 00355 for( size_t i = 0; i < N; i++, ++i_f, ++i_g ) 00356 _p.p().push_back( op( f._p[i_f], g._p[i_g] ) ); 00357 } 00358 return *this; 00359 } 00360 00362 00366 TFactor<T>& operator+= (const TFactor<T>& g) { return binaryOp( g, std::plus<T>() ); } 00367 00369 00373 TFactor<T>& operator-= (const TFactor<T>& g) { return binaryOp( g, std::minus<T>() ); } 00374 00376 00380 TFactor<T>& operator*= (const TFactor<T>& g) { return binaryOp( g, std::multiplies<T>() ); } 00381 00383 00387 TFactor<T>& operator/= (const TFactor<T>& g) { return binaryOp( g, fo_divides0<T>() ); } 00389 00391 00392 00393 00397 template<typename binOp> TFactor<T> binaryTr( const TFactor<T> &g, binOp op ) const { 00398 // Note that to prevent a copy to be made, it is crucial 00399 // that the result is declared outside the if-else construct. 
00400 TFactor<T> result; 00401 if( _vs == g._vs ) { // optimize special case 00402 result._vs = _vs; 00403 result._p = _p.pwBinaryTr( g._p, op ); 00404 } else { 00405 result._vs = _vs | g._vs; 00406 size_t N = BigInt_size_t( result._vs.nrStates() ); 00407 00408 IndexFor i_f( _vs, result.vars() ); 00409 IndexFor i_g( g._vs, result.vars() ); 00410 00411 result._p.p().clear(); 00412 result._p.p().reserve( N ); 00413 for( size_t i = 0; i < N; i++, ++i_f, ++i_g ) 00414 result._p.p().push_back( op( _p[i_f], g[i_g] ) ); 00415 } 00416 return result; 00417 } 00418 00420 00424 TFactor<T> operator+ (const TFactor<T>& g) const { 00425 return binaryTr(g,std::plus<T>()); 00426 } 00427 00429 00433 TFactor<T> operator- (const TFactor<T>& g) const { 00434 return binaryTr(g,std::minus<T>()); 00435 } 00436 00438 00442 TFactor<T> operator* (const TFactor<T>& g) const { 00443 return binaryTr(g,std::multiplies<T>()); 00444 } 00445 00447 00451 TFactor<T> operator/ (const TFactor<T>& g) const { 00452 return binaryTr(g,fo_divides0<T>()); 00453 } 00455 00457 00458 00459 00470 TFactor<T> slice( const VarSet& vars, size_t varsState ) const; 00471 00473 00478 TFactor<T> embed(const VarSet & vars) const { 00479 DAI_ASSERT( vars >> _vs ); 00480 if( _vs == vars ) 00481 return *this; 00482 else 00483 return (*this) * TFactor<T>(vars / _vs, (T)1); 00484 } 00485 00487 TFactor<T> marginal(const VarSet &vars, bool normed=true) const; 00488 00490 TFactor<T> maxMarginal(const VarSet &vars, bool normed=true) const; 00492 }; 00493 00494 00495 template<typename T> TFactor<T> TFactor<T>::slice( const VarSet& vars, size_t varsState ) const { 00496 DAI_ASSERT( vars << _vs ); 00497 VarSet varsrem = _vs / vars; 00498 TFactor<T> result( varsrem, T(0) ); 00499 00500 // OPTIMIZE ME 00501 IndexFor i_vars (vars, _vs); 00502 IndexFor i_varsrem (varsrem, _vs); 00503 for( size_t i = 0; i < nrStates(); i++, ++i_vars, ++i_varsrem ) 00504 if( (size_t)i_vars == varsState ) 00505 result.set( i_varsrem, _p[i] ); 00506 00507 
return result; 00508 } 00509 00510 00511 template<typename T> TFactor<T> TFactor<T>::marginal(const VarSet &vars, bool normed) const { 00512 VarSet res_vars = vars & _vs; 00513 00514 TFactor<T> res( res_vars, 0.0 ); 00515 00516 IndexFor i_res( res_vars, _vs ); 00517 for( size_t i = 0; i < _p.size(); i++, ++i_res ) 00518 res.set( i_res, res[i_res] + _p[i] ); 00519 00520 if( normed ) 00521 res.normalize( NORMPROB ); 00522 00523 return res; 00524 } 00525 00526 00527 template<typename T> TFactor<T> TFactor<T>::maxMarginal(const VarSet &vars, bool normed) const { 00528 VarSet res_vars = vars & _vs; 00529 00530 TFactor<T> res( res_vars, 0.0 ); 00531 00532 IndexFor i_res( res_vars, _vs ); 00533 for( size_t i = 0; i < _p.size(); i++, ++i_res ) 00534 if( _p[i] > res._p[i_res] ) 00535 res.set( i_res, _p[i] ); 00536 00537 if( normed ) 00538 res.normalize( NORMPROB ); 00539 00540 return res; 00541 } 00542 00543 00544 template<typename T> T TFactor<T>::strength( const Var &i, const Var &j ) const { 00545 DAI_DEBASSERT( _vs.contains( i ) ); 00546 DAI_DEBASSERT( _vs.contains( j ) ); 00547 DAI_DEBASSERT( i != j ); 00548 VarSet ij(i, j); 00549 00550 T max = 0.0; 00551 for( size_t alpha1 = 0; alpha1 < i.states(); alpha1++ ) 00552 for( size_t alpha2 = 0; alpha2 < i.states(); alpha2++ ) 00553 if( alpha2 != alpha1 ) 00554 for( size_t beta1 = 0; beta1 < j.states(); beta1++ ) 00555 for( size_t beta2 = 0; beta2 < j.states(); beta2++ ) 00556 if( beta2 != beta1 ) { 00557 size_t as = 1, bs = 1; 00558 if( i < j ) 00559 bs = i.states(); 00560 else 00561 as = j.states(); 00562 T f1 = slice( ij, alpha1 * as + beta1 * bs ).p().divide( slice( ij, alpha2 * as + beta1 * bs ).p() ).max(); 00563 T f2 = slice( ij, alpha2 * as + beta2 * bs ).p().divide( slice( ij, alpha1 * as + beta2 * bs ).p() ).max(); 00564 T f = f1 * f2; 00565 if( f > max ) 00566 max = f; 00567 } 00568 00569 return std::tanh( 0.25 * std::log( max ) ); 00570 } 00571 00572 00574 00576 template<typename T> std::ostream& operator<< 
(std::ostream& os, const TFactor<T>& f) { 00577 os << "(" << f.vars() << ", ("; 00578 for( size_t i = 0; i < f.nrStates(); i++ ) 00579 os << (i == 0 ? "" : ", ") << f[i]; 00580 os << "))"; 00581 return os; 00582 } 00583 00584 00586 00589 template<typename T> T dist( const TFactor<T> &f, const TFactor<T> &g, ProbDistType dt ) { 00590 if( f.vars().empty() || g.vars().empty() ) 00591 return -1; 00592 else { 00593 DAI_DEBASSERT( f.vars() == g.vars() ); 00594 return dist( f.p(), g.p(), dt ); 00595 } 00596 } 00597 00598 00600 00603 template<typename T> TFactor<T> max( const TFactor<T> &f, const TFactor<T> &g ) { 00604 DAI_ASSERT( f.vars() == g.vars() ); 00605 return TFactor<T>( f.vars(), max( f.p(), g.p() ) ); 00606 } 00607 00608 00610 00613 template<typename T> TFactor<T> min( const TFactor<T> &f, const TFactor<T> &g ) { 00614 DAI_ASSERT( f.vars() == g.vars() ); 00615 return TFactor<T>( f.vars(), min( f.p(), g.p() ) ); 00616 } 00617 00618 00620 00623 template<typename T> T MutualInfo(const TFactor<T> &f) { 00624 DAI_ASSERT( f.vars().size() == 2 ); 00625 VarSet::const_iterator it = f.vars().begin(); 00626 Var i = *it; it++; Var j = *it; 00627 TFactor<T> projection = f.marginal(i) * f.marginal(j); 00628 return dist( f.normalized(), projection, DISTKL ); 00629 } 00630 00631 00633 typedef TFactor<Real> Factor; 00634 00635 00637 00640 Factor createFactorIsing( const Var &x, Real h ); 00641 00642 00644 00648 Factor createFactorIsing( const Var &x1, const Var &x2, Real J ); 00649 00650 00652 00657 Factor createFactorExpGauss( const VarSet &vs, Real beta ); 00658 00659 00661 00665 Factor createFactorPotts( const Var &x1, const Var &x2, Real J ); 00666 00667 00669 00672 Factor createFactorDelta( const Var &v, size_t state ); 00673 00674 00676 00679 Factor createFactorDelta( const VarSet& vs, size_t state ); 00680 00681 00682 } // end of namespace dai 00683 00684 00685 #endif