libDAI
|
00001 /* This file is part of libDAI - http://www.libdai.org/ 00002 * 00003 * Copyright (c) 2006-2011, The libDAI authors. All rights reserved. 00004 * 00005 * Use of this source code is governed by a BSD-style license that can be found in the LICENSE file. 00006 */ 00007 00008 00012 00013 00014 #ifndef ___defined_libdai_bbp_h 00015 #define ___defined_libdai_bbp_h 00016 00017 00018 #include <vector> 00019 #include <utility> 00020 00021 #include <dai/prob.h> 00022 #include <dai/daialg.h> 00023 #include <dai/factorgraph.h> 00024 #include <dai/enum.h> 00025 #include <dai/bp_dual.h> 00026 00027 00028 namespace dai { 00029 00030 00032 00034 DAI_ENUM(BBPCostFunctionBase,CFN_GIBBS_B,CFN_GIBBS_B2,CFN_GIBBS_EXP,CFN_GIBBS_B_FACTOR,CFN_GIBBS_B2_FACTOR,CFN_GIBBS_EXP_FACTOR,CFN_VAR_ENT,CFN_FACTOR_ENT,CFN_BETHE_ENT); 00035 00036 00038 class BBPCostFunction : public BBPCostFunctionBase { 00039 public: 00041 BBPCostFunction() : BBPCostFunctionBase() {} 00042 00044 BBPCostFunction( const BBPCostFunctionBase &x ) : BBPCostFunctionBase(x) {} 00045 00047 bool needGibbsState() const; 00048 00050 Real evaluate( const InfAlg &ia, const std::vector<size_t> *stateP ) const; 00051 00053 BBPCostFunction& operator=( const BBPCostFunctionBase &x ) { 00054 BBPCostFunctionBase::operator=( x ); 00055 return *this; 00056 } 00057 }; 00058 00059 00061 00063 class BBP { 00064 private: 00066 00067 00068 BP_dual _bp_dual; 00070 const FactorGraph *_fg; 00072 const InfAlg *_ia; 00074 00076 00077 00078 std::vector<Prob> _adj_psi_V; 00080 std::vector<Prob> _adj_psi_F; 00082 std::vector<std::vector<Prob> > _adj_n; 00084 std::vector<std::vector<Prob> > _adj_m; 00086 std::vector<Prob> _adj_b_V; 00088 std::vector<Prob> _adj_b_F; 00090 00092 00093 00094 std::vector<Prob> _init_adj_psi_V; 00096 std::vector<Prob> _init_adj_psi_F; 00097 00099 std::vector<std::vector<Prob> > _adj_n_unnorm; 00101 std::vector<std::vector<Prob> > _adj_m_unnorm; 00103 std::vector<std::vector<Prob> > _new_adj_n; 00105 std::vector<std::vector<Prob> > _new_adj_m; 00107 std::vector<Prob> _adj_b_V_unnorm; 00109 std::vector<Prob> _adj_b_F_unnorm; 00110 00112 std::vector<std::vector<Prob > > _Tmsg; 00114 std::vector<std::vector<Prob > > _Umsg; 00116 std::vector<std::vector<std::vector<Prob > > > _Smsg; 00118 std::vector<std::vector<std::vector<Prob > > > _Rmsg; 00119 00121 size_t _iters; 00123 00125 00126 00127 typedef std::vector<size_t> _ind_t; 00129 std::vector<std::vector<_ind_t> > _indices; 00131 00133 void RegenerateInds(); 00135 const _ind_t& _index(size_t i, size_t _I) const { return _indices[i][_I]; } 00137 00139 00140 00141 void RegenerateT(); 00143 void RegenerateU(); 00145 void RegenerateS(); 00147 void RegenerateR(); 00149 void RegenerateInputs(); 00151 00153 void RegeneratePsiAdjoints(); 00155 00157 void RegenerateParMessageAdjoints(); 00159 00163 void RegenerateSeqMessageAdjoints(); 00165 void Regenerate(); 00167 00169 00170 00171 Prob & T(size_t i, size_t _I) { return _Tmsg[i][_I]; } 00173 const Prob & T(size_t i, size_t _I) const { return _Tmsg[i][_I]; } 00175 Prob & U(size_t I, size_t _i) { return _Umsg[I][_i]; } 00177 const Prob & U(size_t I, size_t _i) const { return _Umsg[I][_i]; } 00179 Prob & S(size_t i, size_t _I, size_t _j) { return _Smsg[i][_I][_j]; } 00181 const Prob & S(size_t i, size_t _I, size_t _j) const { return _Smsg[i][_I][_j]; } 00183 Prob & R(size_t I, size_t _i, size_t _J) { return _Rmsg[I][_i][_J]; } 00185 const Prob & R(size_t I, size_t _i, size_t _J) const { return _Rmsg[I][_i][_J]; } 00186 00188 Prob& adj_n(size_t i, size_t _I) { return _adj_n[i][_I]; } 00190 const Prob& adj_n(size_t i, size_t _I) const { return _adj_n[i][_I]; } 00192 Prob& adj_m(size_t i, size_t _I) { return _adj_m[i][_I]; } 00194 const Prob& adj_m(size_t i, size_t _I) const { return _adj_m[i][_I]; } 00196 00198 00199 00200 00203 void calcNewN( size_t i, size_t _I ); 00205 00208 void calcNewM( size_t i, size_t _I ); 00210 void calcUnnormMsgN( size_t i, size_t _I ); 00212 void calcUnnormMsgM( size_t i, size_t _I ); 00214 void upMsgN( size_t i, size_t _I ); 00216 void upMsgM( size_t i, size_t _I ); 00218 void doParUpdate(); 00220 00222 00223 00224 void incrSeqMsgM( size_t i, size_t _I, const Prob& p ); 00225 // DISABLED BECAUSE IT IS BUGGY: 00226 // void updateSeqMsgM( size_t i, size_t _I ); 00228 void setSeqMsgM( size_t i, size_t _I, const Prob &p ); 00230 void sendSeqMsgN( size_t i, size_t _I, const Prob &f ); 00232 void sendSeqMsgM( size_t i, size_t _I ); 00234 00236 00238 Prob unnormAdjoint( const Prob &w, Real Z_w, const Prob &adj_w ); 00239 00241 Real getUnMsgMag(); 00243 void getMsgMags( Real &s, Real &new_s ); 00245 void getArgmaxMsgM( size_t &i, size_t &_I, Real &mag ); 00247 Real getMaxMsgM(); 00248 00250 Real getTotalMsgM(); 00252 Real getTotalNewMsgM(); 00254 Real getTotalMsgN(); 00255 00257 std::vector<Prob> getZeroAdjF( const FactorGraph &fg ); 00259 std::vector<Prob> getZeroAdjV( const FactorGraph &fg ); 00260 00261 public: 00263 00264 00265 00268 BBP( const InfAlg *ia, const PropertySet &opts ) : _bp_dual(ia), _fg(&(ia->fg())), _ia(ia) { 00269 props.set(opts); 00270 } 00272 00274 00275 00276 void init( const std::vector<Prob> &adj_b_V, const std::vector<Prob> &adj_b_F, const std::vector<Prob> &adj_psi_V, const std::vector<Prob> &adj_psi_F ) { 00277 _adj_b_V = adj_b_V; 00278 _adj_b_F = adj_b_F; 00279 _init_adj_psi_V = adj_psi_V; 00280 _init_adj_psi_F = adj_psi_F; 00281 Regenerate(); 00282 } 00283 00285 void init( const std::vector<Prob> &adj_b_V, const std::vector<Prob> &adj_b_F ) { 00286 init( adj_b_V, adj_b_F, getZeroAdjV(*_fg), getZeroAdjF(*_fg) ); 00287 } 00288 00290 void init_V( const std::vector<Prob> &adj_b_V ) { 00291 init( adj_b_V, getZeroAdjF(*_fg) ); 00292 } 00293 00295 void init_F( const std::vector<Prob> &adj_b_F ) { 00296 init( getZeroAdjV(*_fg), adj_b_F ); 00297 } 00298 00300 00304 void initCostFnAdj( const BBPCostFunction &cfn, const std::vector<size_t> *stateP ); 00306 00308 00309 00310 void run(); 00312 00314 00315 00316 Prob& adj_psi_V(size_t i) { return _adj_psi_V[i]; } 00318 const Prob& adj_psi_V(size_t i) const { return _adj_psi_V[i]; } 00320 Prob& adj_psi_F(size_t I) { return _adj_psi_F[I]; } 00322 const Prob& adj_psi_F(size_t I) const { return _adj_psi_F[I]; } 00324 Prob& adj_b_V(size_t i) { return _adj_b_V[i]; } 00326 const Prob& adj_b_V(size_t i) const { return _adj_b_V[i]; } 00328 Prob& adj_b_F(size_t I) { return _adj_b_F[I]; } 00330 const Prob& adj_b_F(size_t I) const { return _adj_b_F[I]; } 00332 size_t Iterations() { return _iters; } 00334 00335 public: 00337 /* PROPERTIES(props,BBP) { 00345 DAI_ENUM(UpdateType,SEQ_FIX,SEQ_MAX,SEQ_BP_REV,SEQ_BP_FWD,PAR); 00346 00348 size_t verbose = 0; 00349 00351 size_t maxiter; 00352 00355 Real tol; 00356 00358 Real damping; 00359 00361 UpdateType updates; 00362 00363 // DISABLED BECAUSE IT IS BUGGY: 00364 // bool clean_updates; 00365 } 00366 */ 00367 /* {{{ GENERATED CODE: DO NOT EDIT. Created by 00368 ./scripts/regenerate-properties include/dai/bbp.h src/bbp.cpp 00369 */ 00370 struct Properties { 00372 00379 DAI_ENUM(UpdateType,SEQ_FIX,SEQ_MAX,SEQ_BP_REV,SEQ_BP_FWD,PAR); 00381 size_t verbose; 00383 size_t maxiter; 00385 00387 Real tol; 00389 Real damping; 00391 UpdateType updates; 00392 00394 00397 void set(const PropertySet &opts); 00399 PropertySet get() const; 00401 std::string toString() const; 00402 } props; 00403 /* }}} END OF GENERATED CODE */ 00404 }; 00405 00406 00408 00416 Real numericBBPTest( const InfAlg &bp, const std::vector<size_t> *state, const PropertySet &bbp_props, const BBPCostFunction &cfn, Real h ); 00417 00418 00419 } // end of namespace dai 00420 00421 00422 #endif