00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00015
00016
00017 #ifndef __defined_libdai_factor_h
00018 #define __defined_libdai_factor_h
00019
00020
00021 #include <iostream>
00022 #include <functional>
00023 #include <cmath>
00024 #include <dai/prob.h>
00025 #include <dai/varset.h>
00026 #include <dai/index.h>
00027 #include <dai/util.h>
00028
00029
00030 namespace dai {
00031
00032
00034
00058 template <typename T>
00059 class TFactor {
00060 private:
00062 VarSet _vs;
00064 TProb<T> _p;
00065
00066 public:
00068
00069
00070 TFactor ( T p = 1 ) : _vs(), _p(1,p) {}
00071
00073 TFactor( const Var &v ) : _vs(v), _p(v.states()) {}
00074
00076 TFactor( const VarSet& vars ) : _vs(vars), _p(_vs.nrStates()) {}
00077
00079 TFactor( const VarSet& vars, T p ) : _vs(vars), _p(_vs.nrStates(),p) {}
00080
00082
00086 template<typename S>
00087 TFactor( const VarSet& vars, const std::vector<S> &x ) : _vs(vars), _p() {
00088 DAI_ASSERT( x.size() == vars.nrStates() );
00089 _p = TProb<T>( x.begin(), x.begin() + _vs.nrStates(), _vs.nrStates() );
00090 }
00091
00093
00096 TFactor( const VarSet& vars, const T* p ) : _vs(vars), _p(p, p + _vs.nrStates(), _vs.nrStates()) {}
00097
00099 TFactor( const VarSet& vars, const TProb<T> &p ) : _vs(vars), _p(p) {
00100 DAI_ASSERT( _vs.nrStates() == _p.size() );
00101 }
00102
00104 TFactor( const std::vector<Var> &vars, const std::vector<T> &p ) : _vs(vars.begin(), vars.end(), vars.size()), _p(p.size()) {
00105 size_t nrStates = 1;
00106 for( size_t i = 0; i < vars.size(); i++ )
00107 nrStates *= vars[i].states();
00108 DAI_ASSERT( nrStates == p.size() );
00109 Permute permindex(vars);
00110 for( size_t li = 0; li < p.size(); ++li )
00111 _p.set( permindex.convertLinearIndex(li), p[li] );
00112 }
00114
00116
00117
00118 void set( size_t i, T val ) { _p.set( i, val ); }
00119
00121 T get( size_t i ) const { return _p[i]; }
00123
00125
00126
00127 const TProb<T>& p() const { return _p; }
00128
00130 TProb<T>& p() { return _p; }
00131
00133 T operator[] (size_t i) const { return _p[i]; }
00134
00137 T& operator[] (size_t i) { return _p[i]; }
00138
00140 const VarSet& vars() const { return _vs; }
00141
00143 VarSet& vars() { return _vs; }
00144
00146
00148 size_t nrStates() const { return _p.size(); }
00149
00151
00154 size_t states() const { return _p.size(); }
00155
00157 T entropy() const { return _p.entropy(); }
00158
00160 T max() const { return _p.max(); }
00161
00163 T min() const { return _p.min(); }
00164
00166 T sum() const { return _p.sum(); }
00167
00169 T sumAbs() const { return _p.sumAbs(); }
00170
00172 T maxAbs() const { return _p.maxAbs(); }
00173
00175 bool hasNaNs() const { return _p.hasNaNs(); }
00176
00178 bool hasNegatives() const { return _p.hasNegatives(); }
00179
00181 T strength( const Var &i, const Var &j ) const;
00182
00184 bool operator==( const TFactor<T>& y ) const {
00185 return (_vs == y._vs) && (_p == y._p);
00186 }
00188
00190
00191
00192 TFactor<T> operator- () const {
00193
00194
00195
00196 TFactor<T> x;
00197 x._vs = _vs;
00198 x._p = -_p;
00199 return x;
00200 }
00201
00203 TFactor<T> abs() const {
00204 TFactor<T> x;
00205 x._vs = _vs;
00206 x._p = _p.abs();
00207 return x;
00208 }
00209
00211 TFactor<T> exp() const {
00212 TFactor<T> x;
00213 x._vs = _vs;
00214 x._p = _p.exp();
00215 return x;
00216 }
00217
00219
00221 TFactor<T> log(bool zero=false) const {
00222 TFactor<T> x;
00223 x._vs = _vs;
00224 x._p = _p.log(zero);
00225 return x;
00226 }
00227
00229
00231 TFactor<T> inverse(bool zero=true) const {
00232 TFactor<T> x;
00233 x._vs = _vs;
00234 x._p = _p.inverse(zero);
00235 return x;
00236 }
00237
00239
00241 TFactor<T> normalized( ProbNormType norm=NORMPROB ) const {
00242 TFactor<T> x;
00243 x._vs = _vs;
00244 x._p = _p.normalized( norm );
00245 return x;
00246 }
00248
00250
00251
00252 TFactor<T>& randomize() { _p.randomize(); return *this; }
00253
00255 TFactor<T>& setUniform() { _p.setUniform(); return *this; }
00256
00258 TFactor<T>& takeAbs() { _p.takeAbs(); return *this; }
00259
00261 TFactor<T>& takeExp() { _p.takeExp(); return *this; }
00262
00264
00266 TFactor<T>& takeLog( bool zero = false ) { _p.takeLog(zero); return *this; }
00267
00269
00271 T normalize( ProbNormType norm=NORMPROB ) { return _p.normalize( norm ); }
00273
00275
00276
00277 TFactor<T>& fill (T x) { _p.fill( x ); return *this; }
00278
00280 TFactor<T>& operator+= (T x) { _p += x; return *this; }
00281
00283 TFactor<T>& operator-= (T x) { _p -= x; return *this; }
00284
00286 TFactor<T>& operator*= (T x) { _p *= x; return *this; }
00287
00289 TFactor<T>& operator/= (T x) { _p /= x; return *this; }
00290
00292 TFactor<T>& operator^= (T x) { _p ^= x; return *this; }
00294
00296
00297
00298 TFactor<T> operator+ (T x) const {
00299
00300
00301
00302
00303 TFactor<T> result;
00304 result._vs = _vs;
00305 result._p = p() + x;
00306 return result;
00307 }
00308
00310 TFactor<T> operator- (T x) const {
00311 TFactor<T> result;
00312 result._vs = _vs;
00313 result._p = p() - x;
00314 return result;
00315 }
00316
00318 TFactor<T> operator* (T x) const {
00319 TFactor<T> result;
00320 result._vs = _vs;
00321 result._p = p() * x;
00322 return result;
00323 }
00324
00326 TFactor<T> operator/ (T x) const {
00327 TFactor<T> result;
00328 result._vs = _vs;
00329 result._p = p() / x;
00330 return result;
00331 }
00332
00334 TFactor<T> operator^ (T x) const {
00335 TFactor<T> result;
00336 result._vs = _vs;
00337 result._p = p() ^ x;
00338 return result;
00339 }
00341
00343
00344
00345
00349 template<typename binOp> TFactor<T>& binaryOp( const TFactor<T> &g, binOp op ) {
00350 if( _vs == g._vs )
00351 _p.pwBinaryOp( g._p, op );
00352 else {
00353 TFactor<T> f(*this);
00354 _vs |= g._vs;
00355 size_t N = _vs.nrStates();
00356
00357 IndexFor i_f( f._vs, _vs );
00358 IndexFor i_g( g._vs, _vs );
00359
00360 _p.p().clear();
00361 _p.p().reserve( N );
00362 for( size_t i = 0; i < N; i++, ++i_f, ++i_g )
00363 _p.p().push_back( op( f._p[i_f], g._p[i_g] ) );
00364 }
00365 return *this;
00366 }
00367
00369
00373 TFactor<T>& operator+= (const TFactor<T>& g) { return binaryOp( g, std::plus<T>() ); }
00374
00376
00380 TFactor<T>& operator-= (const TFactor<T>& g) { return binaryOp( g, std::minus<T>() ); }
00381
00383
00387 TFactor<T>& operator*= (const TFactor<T>& g) { return binaryOp( g, std::multiplies<T>() ); }
00388
00390
00394 TFactor<T>& operator/= (const TFactor<T>& g) { return binaryOp( g, fo_divides0<T>() ); }
00396
00398
00399
00400
00404 template<typename binOp> TFactor<T> binaryTr( const TFactor<T> &g, binOp op ) const {
00405
00406
00407 TFactor<T> result;
00408 if( _vs == g._vs ) {
00409 result._vs = _vs;
00410 result._p = _p.pwBinaryTr( g._p, op );
00411 } else {
00412 result._vs = _vs | g._vs;
00413 size_t N = result._vs.nrStates();
00414
00415 IndexFor i_f( _vs, result.vars() );
00416 IndexFor i_g( g._vs, result.vars() );
00417
00418 result._p.p().clear();
00419 result._p.p().reserve( N );
00420 for( size_t i = 0; i < N; i++, ++i_f, ++i_g )
00421 result._p.p().push_back( op( _p[i_f], g[i_g] ) );
00422 }
00423 return result;
00424 }
00425
00427
00431 TFactor<T> operator+ (const TFactor<T>& g) const {
00432 return binaryTr(g,std::plus<T>());
00433 }
00434
00436
00440 TFactor<T> operator- (const TFactor<T>& g) const {
00441 return binaryTr(g,std::minus<T>());
00442 }
00443
00445
00449 TFactor<T> operator* (const TFactor<T>& g) const {
00450 return binaryTr(g,std::multiplies<T>());
00451 }
00452
00454
00458 TFactor<T> operator/ (const TFactor<T>& g) const {
00459 return binaryTr(g,fo_divides0<T>());
00460 }
00462
00464
00465
00466
00477 TFactor<T> slice( const VarSet& vars, size_t varsState ) const;
00478
00480
00485 TFactor<T> embed(const VarSet & vars) const {
00486 DAI_ASSERT( vars >> _vs );
00487 if( _vs == vars )
00488 return *this;
00489 else
00490 return (*this) * TFactor<T>(vars / _vs, (T)1);
00491 }
00492
00494 TFactor<T> marginal(const VarSet &vars, bool normed=true) const;
00495
00497 TFactor<T> maxMarginal(const VarSet &vars, bool normed=true) const;
00499 };
00500
00501
00502 template<typename T> TFactor<T> TFactor<T>::slice( const VarSet& vars, size_t varsState ) const {
00503 DAI_ASSERT( vars << _vs );
00504 VarSet varsrem = _vs / vars;
00505 TFactor<T> result( varsrem, T(0) );
00506
00507
00508 IndexFor i_vars (vars, _vs);
00509 IndexFor i_varsrem (varsrem, _vs);
00510 for( size_t i = 0; i < nrStates(); i++, ++i_vars, ++i_varsrem )
00511 if( (size_t)i_vars == varsState )
00512 result.set( i_varsrem, _p[i] );
00513
00514 return result;
00515 }
00516
00517
00518 template<typename T> TFactor<T> TFactor<T>::marginal(const VarSet &vars, bool normed) const {
00519 VarSet res_vars = vars & _vs;
00520
00521 TFactor<T> res( res_vars, 0.0 );
00522
00523 IndexFor i_res( res_vars, _vs );
00524 for( size_t i = 0; i < _p.size(); i++, ++i_res )
00525 res.set( i_res, res[i_res] + _p[i] );
00526
00527 if( normed )
00528 res.normalize( NORMPROB );
00529
00530 return res;
00531 }
00532
00533
00534 template<typename T> TFactor<T> TFactor<T>::maxMarginal(const VarSet &vars, bool normed) const {
00535 VarSet res_vars = vars & _vs;
00536
00537 TFactor<T> res( res_vars, 0.0 );
00538
00539 IndexFor i_res( res_vars, _vs );
00540 for( size_t i = 0; i < _p.size(); i++, ++i_res )
00541 if( _p[i] > res._p[i_res] )
00542 res.set( i_res, _p[i] );
00543
00544 if( normed )
00545 res.normalize( NORMPROB );
00546
00547 return res;
00548 }
00549
00550
00551 template<typename T> T TFactor<T>::strength( const Var &i, const Var &j ) const {
00552 DAI_DEBASSERT( _vs.contains( i ) );
00553 DAI_DEBASSERT( _vs.contains( j ) );
00554 DAI_DEBASSERT( i != j );
00555 VarSet ij(i, j);
00556
00557 T max = 0.0;
00558 for( size_t alpha1 = 0; alpha1 < i.states(); alpha1++ )
00559 for( size_t alpha2 = 0; alpha2 < i.states(); alpha2++ )
00560 if( alpha2 != alpha1 )
00561 for( size_t beta1 = 0; beta1 < j.states(); beta1++ )
00562 for( size_t beta2 = 0; beta2 < j.states(); beta2++ )
00563 if( beta2 != beta1 ) {
00564 size_t as = 1, bs = 1;
00565 if( i < j )
00566 bs = i.states();
00567 else
00568 as = j.states();
00569 T f1 = slice( ij, alpha1 * as + beta1 * bs ).p().divide( slice( ij, alpha2 * as + beta1 * bs ).p() ).max();
00570 T f2 = slice( ij, alpha2 * as + beta2 * bs ).p().divide( slice( ij, alpha1 * as + beta2 * bs ).p() ).max();
00571 T f = f1 * f2;
00572 if( f > max )
00573 max = f;
00574 }
00575
00576 return std::tanh( 0.25 * std::log( max ) );
00577 }
00578
00579
00581
00583 template<typename T> std::ostream& operator<< (std::ostream& os, const TFactor<T>& f) {
00584 os << "(" << f.vars() << ", (";
00585 for( size_t i = 0; i < f.nrStates(); i++ )
00586 os << (i == 0 ? "" : ", ") << f[i];
00587 os << "))";
00588 return os;
00589 }
00590
00591
00593
00596 template<typename T> T dist( const TFactor<T> &f, const TFactor<T> &g, ProbDistType dt ) {
00597 if( f.vars().empty() || g.vars().empty() )
00598 return -1;
00599 else {
00600 DAI_DEBASSERT( f.vars() == g.vars() );
00601 return dist( f.p(), g.p(), dt );
00602 }
00603 }
00604
00605
00607
00610 template<typename T> TFactor<T> max( const TFactor<T> &f, const TFactor<T> &g ) {
00611 DAI_ASSERT( f.vars() == g.vars() );
00612 return TFactor<T>( f.vars(), max( f.p(), g.p() ) );
00613 }
00614
00615
00617
00620 template<typename T> TFactor<T> min( const TFactor<T> &f, const TFactor<T> &g ) {
00621 DAI_ASSERT( f.vars() == g.vars() );
00622 return TFactor<T>( f.vars(), min( f.p(), g.p() ) );
00623 }
00624
00625
00627
00630 template<typename T> T MutualInfo(const TFactor<T> &f) {
00631 DAI_ASSERT( f.vars().size() == 2 );
00632 VarSet::const_iterator it = f.vars().begin();
00633 Var i = *it; it++; Var j = *it;
00634 TFactor<T> projection = f.marginal(i) * f.marginal(j);
00635 return dist( f.normalized(), projection, DISTKL );
00636 }
00637
00638
00640 typedef TFactor<Real> Factor;
00641
00642
00644
00647 Factor createFactorIsing( const Var &x, Real h );
00648
00649
00651
00655 Factor createFactorIsing( const Var &x1, const Var &x2, Real J );
00656
00657
00659
00664 Factor createFactorExpGauss( const VarSet &vs, Real beta );
00665
00666
00668
00672 Factor createFactorPotts( const Var &x1, const Var &x2, Real J );
00673
00674
00676
00679 Factor createFactorDelta( const Var &v, size_t state );
00680
00681
00682 }
00683
00684
00685 #endif