00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00015
00016
00017 #ifndef ___defined_libdai_bbp_h
00018 #define ___defined_libdai_bbp_h
00019
00020
00021 #include <vector>
00022 #include <utility>
00023
00024 #include <dai/prob.h>
00025 #include <dai/daialg.h>
00026 #include <dai/factorgraph.h>
00027 #include <dai/enum.h>
00028 #include <dai/bp_dual.h>
00029
00030
00031 namespace dai {
00032
00033
00035
// Enumeration of the cost functions that a BBPCostFunction (and hence
// BBP::initCostFnAdj) can represent. This generates the class
// BBPCostFunctionBase, whose enumerators are the CFN_* names below.
// The names suggest Gibbs-state based costs (CFN_GIBBS_*, per variable or
// *_FACTOR per factor) and entropy based costs (CFN_*_ENT) — confirm in bbp.cpp.
// NOTE(review): DAI_ENUM (dai/enum.h) presumably stringifies this argument
// list to implement name<->value conversion, so do not insert whitespace
// between the enumerators — confirm before reformatting.
00037 DAI_ENUM(BBPCostFunctionBase,CFN_GIBBS_B,CFN_GIBBS_B2,CFN_GIBBS_EXP,CFN_GIBBS_B_FACTOR,CFN_GIBBS_B2_FACTOR,CFN_GIBBS_EXP_FACTOR,CFN_VAR_ENT,CFN_FACTOR_ENT,CFN_BETHE_ENT);
00038
00039
00041 class BBPCostFunction : public BBPCostFunctionBase {
00042 public:
00044 BBPCostFunction() : BBPCostFunctionBase() {}
00045
00047 BBPCostFunction( const BBPCostFunctionBase &x ) : BBPCostFunctionBase(x) {}
00048
00050 bool needGibbsState() const;
00051
00053 Real evaluate( const InfAlg &ia, const std::vector<size_t> *stateP ) const;
00054
00056 BBPCostFunction& operator=( const BBPCostFunctionBase &x ) {
00057 if( this != &x ) {
00058 (BBPCostFunctionBase)*this = x;
00059 }
00060 return *this;
00061 }
00062 };
00063
00064
00066
00068 class BBP {
00069 private:
00071
00072
00073 BP_dual _bp_dual;
00075 const FactorGraph *_fg;
00077 const InfAlg *_ia;
00079
00081
00082
00083 std::vector<Prob> _adj_psi_V;
00085 std::vector<Prob> _adj_psi_F;
00087 std::vector<std::vector<Prob> > _adj_n;
00089 std::vector<std::vector<Prob> > _adj_m;
00091 std::vector<Prob> _adj_b_V;
00093 std::vector<Prob> _adj_b_F;
00095
00097
00098
00099 std::vector<Prob> _init_adj_psi_V;
00101 std::vector<Prob> _init_adj_psi_F;
00102
00104 std::vector<std::vector<Prob> > _adj_n_unnorm;
00106 std::vector<std::vector<Prob> > _adj_m_unnorm;
00108 std::vector<std::vector<Prob> > _new_adj_n;
00110 std::vector<std::vector<Prob> > _new_adj_m;
00112 std::vector<Prob> _adj_b_V_unnorm;
00114 std::vector<Prob> _adj_b_F_unnorm;
00115
00117 std::vector<std::vector<Prob > > _Tmsg;
00119 std::vector<std::vector<Prob > > _Umsg;
00121 std::vector<std::vector<std::vector<Prob > > > _Smsg;
00123 std::vector<std::vector<std::vector<Prob > > > _Rmsg;
00124
00126 size_t _iters;
00128
00130
00131
00132 typedef std::vector<size_t> _ind_t;
00134 std::vector<std::vector<_ind_t> > _indices;
00136
00138 void RegenerateInds();
00140 const _ind_t& _index(size_t i, size_t _I) const { return _indices[i][_I]; }
00142
00144
00145
00146 void RegenerateT();
00148 void RegenerateU();
00150 void RegenerateS();
00152 void RegenerateR();
00154 void RegenerateInputs();
00156
00158 void RegeneratePsiAdjoints();
00160
00162 void RegenerateParMessageAdjoints();
00164
00168 void RegenerateSeqMessageAdjoints();
00170 void Regenerate();
00172
00174
00175
00176 Prob & T(size_t i, size_t _I) { return _Tmsg[i][_I]; }
00178 const Prob & T(size_t i, size_t _I) const { return _Tmsg[i][_I]; }
00180 Prob & U(size_t I, size_t _i) { return _Umsg[I][_i]; }
00182 const Prob & U(size_t I, size_t _i) const { return _Umsg[I][_i]; }
00184 Prob & S(size_t i, size_t _I, size_t _j) { return _Smsg[i][_I][_j]; }
00186 const Prob & S(size_t i, size_t _I, size_t _j) const { return _Smsg[i][_I][_j]; }
00188 Prob & R(size_t I, size_t _i, size_t _J) { return _Rmsg[I][_i][_J]; }
00190 const Prob & R(size_t I, size_t _i, size_t _J) const { return _Rmsg[I][_i][_J]; }
00191
00193 Prob& adj_n(size_t i, size_t _I) { return _adj_n[i][_I]; }
00195 const Prob& adj_n(size_t i, size_t _I) const { return _adj_n[i][_I]; }
00197 Prob& adj_m(size_t i, size_t _I) { return _adj_m[i][_I]; }
00199 const Prob& adj_m(size_t i, size_t _I) const { return _adj_m[i][_I]; }
00201
00203
00204
00205
00208 void calcNewN( size_t i, size_t _I );
00210
00213 void calcNewM( size_t i, size_t _I );
00215 void calcUnnormMsgN( size_t i, size_t _I );
00217 void calcUnnormMsgM( size_t i, size_t _I );
00219 void upMsgN( size_t i, size_t _I );
00221 void upMsgM( size_t i, size_t _I );
00223 void doParUpdate();
00225
00227
00228
00229 void incrSeqMsgM( size_t i, size_t _I, const Prob& p );
00230
00231
00233 void setSeqMsgM( size_t i, size_t _I, const Prob &p );
00235 void sendSeqMsgN( size_t i, size_t _I, const Prob &f );
00237 void sendSeqMsgM( size_t i, size_t _I );
00239
00241
00243 Prob unnormAdjoint( const Prob &w, Real Z_w, const Prob &adj_w );
00244
00246 Real getUnMsgMag();
00248 void getMsgMags( Real &s, Real &new_s );
00250 void getArgmaxMsgM( size_t &i, size_t &_I, Real &mag );
00252 Real getMaxMsgM();
00253
00255 Real getTotalMsgM();
00257 Real getTotalNewMsgM();
00259 Real getTotalMsgN();
00260
00262 std::vector<Prob> getZeroAdjF( const FactorGraph &fg );
00264 std::vector<Prob> getZeroAdjV( const FactorGraph &fg );
00265
00266 public:
00268
00269
00270
00273 BBP( const InfAlg *ia, const PropertySet &opts ) : _bp_dual(ia), _fg(&(ia->fg())), _ia(ia) {
00274 props.set(opts);
00275 }
00277
00279
00280
00281 void init( const std::vector<Prob> &adj_b_V, const std::vector<Prob> &adj_b_F, const std::vector<Prob> &adj_psi_V, const std::vector<Prob> &adj_psi_F ) {
00282 _adj_b_V = adj_b_V;
00283 _adj_b_F = adj_b_F;
00284 _init_adj_psi_V = adj_psi_V;
00285 _init_adj_psi_F = adj_psi_F;
00286 Regenerate();
00287 }
00288
00290 void init( const std::vector<Prob> &adj_b_V, const std::vector<Prob> &adj_b_F ) {
00291 init( adj_b_V, adj_b_F, getZeroAdjV(*_fg), getZeroAdjF(*_fg) );
00292 }
00293
00295 void init_V( const std::vector<Prob> &adj_b_V ) {
00296 init( adj_b_V, getZeroAdjF(*_fg) );
00297 }
00298
00300 void init_F( const std::vector<Prob> &adj_b_F ) {
00301 init( getZeroAdjV(*_fg), adj_b_F );
00302 }
00303
00305
00309 void initCostFnAdj( const BBPCostFunction &cfn, const std::vector<size_t> *stateP );
00311
00313
00314
00315 void run();
00317
00319
00320
00321 Prob& adj_psi_V(size_t i) { return _adj_psi_V[i]; }
00323 const Prob& adj_psi_V(size_t i) const { return _adj_psi_V[i]; }
00325 Prob& adj_psi_F(size_t I) { return _adj_psi_F[I]; }
00327 const Prob& adj_psi_F(size_t I) const { return _adj_psi_F[I]; }
00329 Prob& adj_b_V(size_t i) { return _adj_b_V[i]; }
00331 const Prob& adj_b_V(size_t i) const { return _adj_b_V[i]; }
00333 Prob& adj_b_F(size_t I) { return _adj_b_F[I]; }
00335 const Prob& adj_b_F(size_t I) const { return _adj_b_F[I]; }
00337 size_t Iterations() { return _iters; }
00339
00340 public:
00342
00350
00351
00353
00354
00356
00357
00360
00361
00363
00364
00366
00367
00368
00369
00370
00371
00372
00373
00374
00375 struct Properties {
00377
00384 DAI_ENUM(UpdateType,SEQ_FIX,SEQ_MAX,SEQ_BP_REV,SEQ_BP_FWD,PAR);
00386 size_t verbose;
00388 size_t maxiter;
00390
00392 Real tol;
00394 Real damping;
00396 UpdateType updates;
00397
00399
00402 void set(const PropertySet &opts);
00404 PropertySet get() const;
00406 std::string toString() const;
00407 } props;
00408
00409 };
00410
00411
00413
00421 Real numericBBPTest( const InfAlg &bp, const std::vector<size_t> *state, const PropertySet &bbp_props, const BBPCostFunction &cfn, Real h );
00422
00423
00424 }
00425
00426
00427 #endif