00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00014
00015
00016 #ifndef __defined_libdai_factorgraph_h
00017 #define __defined_libdai_factorgraph_h
00018
00019
#include <algorithm>
#include <iostream>
#include <map>
#include <set>
#include <vector>
#include <dai/bipgraph.h>
#include <dai/graph.h>
#include <dai/factor.h>
00025
00026
00027 namespace dai {
00028
00029
00031
00066 class FactorGraph {
00067 public:
00069 typedef BipartiteGraph::Neighbor Neighbor;
00070
00072 typedef BipartiteGraph::Neighbors Neighbors;
00073
00075 typedef BipartiteGraph::Edge Edge;
00076
00078
00080 typedef std::vector<Factor>::iterator iterator;
00081
00083
00085 typedef std::vector<Factor>::const_iterator const_iterator;
00086
00087 private:
00089 BipartiteGraph _G;
00091 std::vector<Var> _vars;
00093 std::vector<Factor> _factors;
00095 std::map<size_t,Factor> _backup;
00096
00097 public:
00099
00100
00101 FactorGraph() : _G(), _vars(), _factors(), _backup() {}
00102
00104 FactorGraph( const std::vector<Factor>& P );
00105
00107
00111 template<typename FactorInputIterator, typename VarInputIterator>
00112 FactorGraph(FactorInputIterator facBegin, FactorInputIterator facEnd, VarInputIterator varBegin, VarInputIterator varEnd, size_t nrFacHint = 0, size_t nrVarHint = 0 );
00113
00115 virtual ~FactorGraph() {}
00116
00118 virtual FactorGraph* clone() const { return new FactorGraph(*this); }
00120
00122
00123
00124 const Var& var( size_t i ) const {
00125 DAI_DEBASSERT( i < nrVars() );
00126 return _vars[i];
00127 }
00128
00130 const std::vector<Var>& vars() const { return _vars; }
00131
00133
00136 Factor& factor( size_t I ) {
00137 DAI_DEBASSERT( I < nrFactors() );
00138 return _factors[I];
00139 }
00140
00142 const Factor& factor( size_t I ) const {
00143 DAI_DEBASSERT( I < nrFactors() );
00144 return _factors[I];
00145 }
00147 const std::vector<Factor>& factors() const { return _factors; }
00148
00150 const Neighbors& nbV( size_t i ) const { return _G.nb1(i); }
00152 const Neighbors& nbF( size_t I ) const { return _G.nb2(I); }
00154 const Neighbor& nbV( size_t i, size_t _I ) const { return _G.nb1(i)[_I]; }
00156 const Neighbor& nbF( size_t I, size_t _i ) const { return _G.nb2(I)[_i]; }
00158
00160
00161
00162
00164 iterator begin() { return _factors.begin(); }
00166
00168 const_iterator begin() const { return _factors.begin(); }
00170
00172 iterator end() { return _factors.end(); }
00174
00176 const_iterator end() const { return _factors.end(); }
00178
00180
00181
00182 const BipartiteGraph& bipGraph() const { return _G; }
00184 size_t nrVars() const { return vars().size(); }
00186 size_t nrFactors() const { return factors().size(); }
00188
00190 size_t nrEdges() const { return _G.nrEdges(); }
00191
00193
00196 size_t findVar( const Var& n ) const {
00197 size_t i = find( vars().begin(), vars().end(), n ) - vars().begin();
00198 if( i == nrVars() )
00199 DAI_THROW(OBJECT_NOT_FOUND);
00200 return i;
00201 }
00202
00204
00207 SmallSet<size_t> findVars( const VarSet& ns ) const {
00208 SmallSet<size_t> result;
00209 for( VarSet::const_iterator n = ns.begin(); n != ns.end(); n++ )
00210 result.insert( findVar( *n ) );
00211 return result;
00212 }
00213
00215
00218 size_t findFactor( const VarSet& ns ) const {
00219 size_t I;
00220 for( I = 0; I < nrFactors(); I++ )
00221 if( factor(I).vars() == ns )
00222 break;
00223 if( I == nrFactors() )
00224 DAI_THROW(OBJECT_NOT_FOUND);
00225 return I;
00226 }
00227
00229 VarSet Delta( size_t i ) const;
00230
00232 VarSet Delta( const VarSet& vs ) const;
00233
00235 VarSet delta( size_t i ) const {
00236 return( Delta( i ) / var( i ) );
00237 }
00238
00240 VarSet delta( const VarSet& vs ) const {
00241 return Delta( vs ) / vs;
00242 }
00243
00245 bool isConnected() const { return _G.isConnected(); }
00246
00248 bool isTree() const { return _G.isTree(); }
00249
00251 bool isPairwise() const;
00252
00254 bool isBinary() const;
00255
00257
00260 GraphAL MarkovGraph() const;
00261
00263
00266 std::vector<VarSet> maximalFactorDomains() const;
00267
00269
00271 std::vector<VarSet> cliques() const {
00272 return maximalFactorDomains();
00273 }
00275
00277
00278
00279 virtual void setFactor( size_t I, const Factor& newFactor, bool backup = false ) {
00280 DAI_ASSERT( newFactor.vars() == factor(I).vars() );
00281 if( backup )
00282 backupFactor( I );
00283 _factors[I] = newFactor;
00284 }
00285
00287 virtual void setFactors( const std::map<size_t, Factor>& facs, bool backup = false ) {
00288 for( std::map<size_t, Factor>::const_iterator fac = facs.begin(); fac != facs.end(); fac++ ) {
00289 if( backup )
00290 backupFactor( fac->first );
00291 setFactor( fac->first, fac->second );
00292 }
00293 }
00294
00296
00298 void backupFactor( size_t I );
00299
00301
00303 void restoreFactor( size_t I );
00304
00306
00308 virtual void backupFactors( const std::set<size_t>& facs );
00309
00311 virtual void restoreFactors();
00312
00314
00316 void backupFactors( const VarSet& ns );
00317
00319 void restoreFactors( const VarSet& ns );
00321
00323
00324
00325 FactorGraph maximalFactors() const;
00326
00328
00331 FactorGraph clamped( size_t i, size_t x ) const;
00333
00335
00336
00337
00339 virtual void clamp( size_t i, size_t x, bool backup = false );
00340
00342
00344 void clampVar( size_t i, const std::vector<size_t>& xis, bool backup = false );
00345
00347
00349 void clampFactor( size_t I, const std::vector<size_t>& xIs, bool backup = false );
00350
00352
00354 virtual void makeCavity( size_t i, bool backup = false );
00356
00358
00359
00360
00364 virtual void ReadFromFile( const char *filename );
00365
00367
00370 virtual void WriteToFile( const char *filename, size_t precision=15 ) const;
00371
00373
00375 friend std::ostream& operator<< (std::ostream& os, const FactorGraph& fg );
00376
00378
00381 friend std::istream& operator>> (std::istream& is, FactorGraph& fg );
00382
00384 virtual void printDot( std::ostream& os ) const;
00386
00387 private:
00389 void constructGraph( size_t nrEdges );
00390 };
00391
00392
00393 template<typename FactorInputIterator, typename VarInputIterator>
00394 FactorGraph::FactorGraph(FactorInputIterator facBegin, FactorInputIterator facEnd, VarInputIterator varBegin, VarInputIterator varEnd, size_t nrFacHint, size_t nrVarHint ) : _G(), _backup() {
00395
00396 size_t nrEdges = 0;
00397 _factors.reserve( nrFacHint );
00398 for( FactorInputIterator p2 = facBegin; p2 != facEnd; ++p2 ) {
00399 _factors.push_back( *p2 );
00400 nrEdges += p2->vars().size();
00401 }
00402
00403
00404 _vars.reserve( nrVarHint );
00405 for( VarInputIterator p1 = varBegin; p1 != varEnd; ++p1 )
00406 _vars.push_back( *p1 );
00407
00408
00409 constructGraph( nrEdges );
00410 }
00411
00412
00419 }
00420
00421
00422 #endif