00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00014
00015
00016 #ifndef __defined_libdai_factorgraph_h
00017 #define __defined_libdai_factorgraph_h
00018
00019
#include <algorithm>
#include <iostream>
#include <map>
#include <set>
#include <vector>
#include <dai/bipgraph.h>
#include <dai/factor.h>
00024
00025
00026 namespace dai {
00027
00028
00030
00065 class FactorGraph {
00066 public:
00068 BipartiteGraph G;
00069
00071 typedef BipartiteGraph::Neighbor Neighbor;
00072
00074 typedef BipartiteGraph::Neighbors Neighbors;
00075
00077 typedef BipartiteGraph::Edge Edge;
00078
00080 typedef std::vector<Factor>::iterator iterator;
00081
00083 typedef std::vector<Factor>::const_iterator const_iterator;
00084
00085
00086 private:
00088 std::vector<Var> _vars;
00090 std::vector<Factor> _factors;
00092 std::map<size_t,Factor> _backup;
00093
00094 public:
00096
00097
00098 FactorGraph() : G(), _vars(), _factors(), _backup() {}
00099
00101 FactorGraph( const std::vector<Factor> &P );
00102
00104
00108 template<typename FactorInputIterator, typename VarInputIterator>
00109 FactorGraph(FactorInputIterator fact_begin, FactorInputIterator fact_end, VarInputIterator var_begin, VarInputIterator var_end, size_t nr_fact_hint = 0, size_t nr_var_hint = 0 );
00110
00112 virtual ~FactorGraph() {}
00113
00115 virtual FactorGraph* clone() const { return new FactorGraph(); }
00117
00119
00120
00121 const Var & var(size_t i) const { return _vars[i]; }
00123 const std::vector<Var> & vars() const { return _vars; }
00124
00126 Factor & factor(size_t I) { return _factors[I]; }
00128 const Factor & factor(size_t I) const { return _factors[I]; }
00130 const std::vector<Factor> & factors() const { return _factors; }
00131
00133 const Neighbors & nbV( size_t i ) const { return G.nb1(i); }
00135 const Neighbors & nbF( size_t I ) const { return G.nb2(I); }
00137 const Neighbor & nbV( size_t i, size_t _I ) const { return G.nb1(i)[_I]; }
00139 const Neighbor & nbF( size_t I, size_t _i ) const { return G.nb2(I)[_i]; }
00141
00143
00144
00145 iterator begin() { return _factors.begin(); }
00147 const_iterator begin() const { return _factors.begin(); }
00149 iterator end() { return _factors.end(); }
00151 const_iterator end() const { return _factors.end(); }
00153
00155
00156
00157 size_t nrVars() const { return vars().size(); }
00159 size_t nrFactors() const { return factors().size(); }
00161
00163 size_t nrEdges() const { return G.nrEdges(); }
00164
00166
00169 size_t findVar( const Var &n ) const {
00170 size_t i = find( vars().begin(), vars().end(), n ) - vars().begin();
00171 if( i == nrVars() )
00172 DAI_THROW(OBJECT_NOT_FOUND);
00173 return i;
00174 }
00175
00177
00180 std::set<size_t> findVars( VarSet &ns ) const {
00181 std::set<size_t> indexes;
00182 for( VarSet::const_iterator n = ns.begin(); n != ns.end(); n++ )
00183 indexes.insert( findVar( *n ) );
00184 return indexes;
00185 }
00186
00188
00191 size_t findFactor( const VarSet &ns ) const {
00192 size_t I;
00193 for( I = 0; I < nrFactors(); I++ )
00194 if( factor(I).vars() == ns )
00195 break;
00196 if( I == nrFactors() )
00197 DAI_THROW(OBJECT_NOT_FOUND);
00198 return I;
00199 }
00200
00202 VarSet Delta( size_t i ) const;
00203
00205 VarSet Delta( const VarSet &vs ) const;
00206
00208 VarSet delta( size_t i ) const;
00209
00211 VarSet delta( const VarSet &vs ) const {
00212 return Delta( vs ) / vs;
00213 }
00214
00216 bool isConnected() const { return G.isConnected(); }
00217
00219 bool isTree() const { return G.isTree(); }
00220
00222 bool isPairwise() const;
00223
00225 bool isBinary() const;
00226
00228 std::vector<VarSet> Cliques() const;
00230
00232
00233
00234 virtual void setFactor( size_t I, const Factor &newFactor, bool backup = false ) {
00235 DAI_ASSERT( newFactor.vars() == factor(I).vars() );
00236 if( backup )
00237 backupFactor( I );
00238 _factors[I] = newFactor;
00239 }
00240
00242 virtual void setFactors( const std::map<size_t, Factor> & facs, bool backup = false ) {
00243 for( std::map<size_t, Factor>::const_iterator fac = facs.begin(); fac != facs.end(); fac++ ) {
00244 if( backup )
00245 backupFactor( fac->first );
00246 setFactor( fac->first, fac->second );
00247 }
00248 }
00249
00251
00253 void backupFactor( size_t I );
00254
00256 void restoreFactor( size_t I );
00257
00259
00261 virtual void backupFactors( const std::set<size_t> & facs );
00262
00264 virtual void restoreFactors();
00265
00267
00269 void backupFactors( const VarSet &ns );
00270
00272 void restoreFactors( const VarSet &ns );
00274
00276
00277
00278 FactorGraph maximalFactors() const;
00279
00281
00284 FactorGraph clamped( size_t i, size_t x ) const;
00286
00288
00289
00290
00292 virtual void clamp( size_t i, size_t x, bool backup = false );
00293
00295
00297 void clampVar( size_t i, const std::vector<size_t> &xis, bool backup = false );
00298
00300
00302 void clampFactor( size_t I, const std::vector<size_t> &xIs, bool backup = false );
00303
00305
00307 virtual void makeCavity( size_t i, bool backup = false );
00309
00311
00312
00313
00317 void ReadFromFile( const char *filename );
00318
00320
00323 void WriteToFile( const char *filename, size_t precision=15 ) const;
00324
00326
00328 friend std::ostream& operator<< (std::ostream &os, const FactorGraph &fg );
00329
00331
00334 friend std::istream& operator>> (std::istream &is, FactorGraph &fg );
00335
00337 void printDot( std::ostream& os ) const;
00339
00340 private:
00342 void constructGraph( size_t nrEdges );
00343 };
00344
00345
00346 template<typename FactorInputIterator, typename VarInputIterator>
00347 FactorGraph::FactorGraph(FactorInputIterator fact_begin, FactorInputIterator fact_end, VarInputIterator var_begin, VarInputIterator var_end, size_t nr_fact_hint, size_t nr_var_hint ) : G(), _backup() {
00348
00349 size_t nrEdges = 0;
00350 _factors.reserve( nr_fact_hint );
00351 for( FactorInputIterator p2 = fact_begin; p2 != fact_end; ++p2 ) {
00352 _factors.push_back( *p2 );
00353 nrEdges += p2->vars().size();
00354 }
00355
00356
00357 _vars.reserve( nr_var_hint );
00358 for( VarInputIterator p1 = var_begin; p1 != var_end; ++p1 )
00359 _vars.push_back( *p1 );
00360
00361
00362 constructGraph( nrEdges );
00363 }
00364
00365
00379 }
00380
00381
00382 #endif