libDAI
|
00001 /* This file is part of libDAI - http://www.libdai.org/ 00002 * 00003 * Copyright (c) 2006-2011, The libDAI authors. All rights reserved. 00004 * 00005 * Use of this source code is governed by a BSD-style license that can be found in the LICENSE file. 00006 */ 00007 00008 00012 00013 00014 #ifndef __defined_libdai_bp_h 00015 #define __defined_libdai_bp_h 00016 00017 00018 #include <string> 00019 #include <dai/daialg.h> 00020 #include <dai/factorgraph.h> 00021 #include <dai/properties.h> 00022 #include <dai/enum.h> 00023 00024 00025 namespace dai { 00026 00027 00029 00059 class BP : public DAIAlgFG { 00060 protected: 00062 typedef std::vector<size_t> ind_t; 00064 struct EdgeProp { 00066 ind_t index; 00068 Prob message; 00070 Prob newMessage; 00072 Real residual; 00073 }; 00075 std::vector<std::vector<EdgeProp> > _edges; 00077 typedef std::multimap<Real, std::pair<std::size_t, std::size_t> > LutType; 00079 std::vector<std::vector<LutType::iterator> > _edge2lut; 00081 LutType _lut; 00083 Real _maxdiff; 00085 size_t _iters; 00087 std::vector<std::pair<std::size_t, std::size_t> > _sentMessages; 00089 std::vector<Factor> _oldBeliefsV; 00091 std::vector<Factor> _oldBeliefsF; 00093 std::vector<Edge> _updateSeq; 00094 00095 public: 00097 struct Properties { 00099 00105 DAI_ENUM(UpdateType,SEQFIX,SEQRND,SEQMAX,PARALL); 00106 00108 00112 DAI_ENUM(InfType,SUMPROD,MAXPROD); 00113 00115 size_t verbose; 00116 00118 size_t maxiter; 00119 00121 double maxtime; 00122 00124 Real tol; 00125 00127 bool logdomain; 00128 00130 Real damping; 00131 00133 UpdateType updates; 00134 00136 InfType inference; 00137 } props; 00138 00140 bool recordSentMessages; 00141 00142 public: 00144 00145 00146 BP() : DAIAlgFG(), _edges(), _edge2lut(), _lut(), _maxdiff(0.0), _iters(0U), _sentMessages(), _oldBeliefsV(), _oldBeliefsF(), _updateSeq(), props(), recordSentMessages(false) {} 00147 00149 00152 BP( const FactorGraph & fg, const PropertySet &opts ) : DAIAlgFG(fg), _edges(), _maxdiff(0.0), _iters(0U), _sentMessages(), _oldBeliefsV(), _oldBeliefsF(), _updateSeq(), props(), recordSentMessages(false) { 00153 setProperties( opts ); 00154 construct(); 00155 } 00156 00158 BP( const BP &x ) : DAIAlgFG(x), _edges(x._edges), _edge2lut(x._edge2lut), _lut(x._lut), _maxdiff(x._maxdiff), _iters(x._iters), _sentMessages(x._sentMessages), _oldBeliefsV(x._oldBeliefsV), _oldBeliefsF(x._oldBeliefsF), _updateSeq(x._updateSeq), props(x.props), recordSentMessages(x.recordSentMessages) { 00159 for( LutType::iterator l = _lut.begin(); l != _lut.end(); ++l ) 00160 _edge2lut[l->second.first][l->second.second] = l; 00161 } 00162 00164 BP& operator=( const BP &x ) { 00165 if( this != &x ) { 00166 DAIAlgFG::operator=( x ); 00167 _edges = x._edges; 00168 _lut = x._lut; 00169 for( LutType::iterator l = _lut.begin(); l != _lut.end(); ++l ) 00170 _edge2lut[l->second.first][l->second.second] = l; 00171 _maxdiff = x._maxdiff; 00172 _iters = x._iters; 00173 _sentMessages = x._sentMessages; 00174 _oldBeliefsV = x._oldBeliefsV; 00175 _oldBeliefsF = x._oldBeliefsF; 00176 _updateSeq = x._updateSeq; 00177 props = x.props; 00178 recordSentMessages = x.recordSentMessages; 00179 } 00180 return *this; 00181 } 00183 00185 00186 virtual BP* clone() const { return new BP(*this); } 00187 virtual BP* construct( const FactorGraph &fg, const PropertySet &opts ) const { return new BP( fg, opts ); } 00188 virtual std::string name() const { return "BP"; } 00189 virtual Factor belief( const Var &v ) const { return beliefV( findVar( v ) ); } 00190 virtual Factor belief( const VarSet &vs ) const; 00191 virtual Factor beliefV( size_t i ) const; 00192 virtual Factor beliefF( size_t I ) const; 00193 virtual std::vector<Factor> beliefs() const; 00194 virtual Real logZ() const; 00197 std::vector<std::size_t> findMaximum() const { return dai::findMaximum( *this ); } 00198 virtual void init(); 00199 virtual void init( const VarSet &ns ); 00200 virtual Real run(); 00201 virtual Real maxDiff() const { return _maxdiff; } 00202 virtual size_t Iterations() const { return _iters; } 00203 virtual void setMaxIter( size_t maxiter ) { props.maxiter = maxiter; } 00204 virtual void setProperties( const PropertySet &opts ); 00205 virtual PropertySet getProperties() const; 00206 virtual std::string printProperties() const; 00208 00210 00211 00212 const std::vector<std::pair<std::size_t, std::size_t> >& getSentMessages() const { 00213 return _sentMessages; 00214 } 00215 00217 void clearSentMessages() { _sentMessages.clear(); } 00219 00220 protected: 00222 const Prob & message(size_t i, size_t _I) const { return _edges[i][_I].message; } 00224 Prob & message(size_t i, size_t _I) { return _edges[i][_I].message; } 00226 const Prob & newMessage(size_t i, size_t _I) const { return _edges[i][_I].newMessage; } 00228 Prob & newMessage(size_t i, size_t _I) { return _edges[i][_I].newMessage; } 00230 const ind_t & index(size_t i, size_t _I) const { return _edges[i][_I].index; } 00232 ind_t & index(size_t i, size_t _I) { return _edges[i][_I].index; } 00234 const Real & residual(size_t i, size_t _I) const { return _edges[i][_I].residual; } 00236 Real & residual(size_t i, size_t _I) { return _edges[i][_I].residual; } 00237 00239 00242 virtual Prob calcIncomingMessageProduct( size_t I, bool without_i, size_t i ) const; 00244 virtual void calcNewMessage( size_t i, size_t _I ); 00246 void updateMessage( size_t i, size_t _I ); 00248 void updateResidual( size_t i, size_t _I, Real r ); 00250 void findMaxResidual( size_t &i, size_t &_I ); 00252 virtual void calcBeliefV( size_t i, Prob &p ) const; 00254 virtual void calcBeliefF( size_t I, Prob &p ) const { 00255 p = calcIncomingMessageProduct( I, false, 0 ); 00256 } 00257 00259 virtual void construct(); 00260 }; 00261 00262 00263 } // end of namespace dai 00264 00265 00266 #endif