libDAI
|
00001 /* This file is part of libDAI - http://www.libdai.org/ 00002 * 00003 * Copyright (c) 2006-2011, The libDAI authors. All rights reserved. 00004 * 00005 * Use of this source code is governed by a BSD-style license that can be found in the LICENSE file. 00006 */ 00007 00008 00011 00012 00013 #ifndef __defined_libdai_factorgraph_h 00014 #define __defined_libdai_factorgraph_h 00015 00016 00017 #include <iostream> 00018 #include <map> 00019 #include <dai/bipgraph.h> 00020 #include <dai/graph.h> 00021 #include <dai/factor.h> 00022 00023 00024 namespace dai { 00025 00026 00028 00065 class FactorGraph { 00066 private: 00068 BipartiteGraph _G; 00070 std::vector<Var> _vars; 00072 std::vector<Factor> _factors; 00074 std::map<size_t,Factor> _backup; 00075 00076 public: 00078 00079 00080 FactorGraph() : _G(), _vars(), _factors(), _backup() {} 00081 00083 FactorGraph( const std::vector<Factor>& P ); 00084 00086 00090 template<typename FactorInputIterator, typename VarInputIterator> 00091 FactorGraph(FactorInputIterator facBegin, FactorInputIterator facEnd, VarInputIterator varBegin, VarInputIterator varEnd, size_t nrFacHint = 0, size_t nrVarHint = 0 ); 00092 00094 virtual ~FactorGraph() {} 00095 00097 virtual FactorGraph* clone() const { return new FactorGraph(*this); } 00099 00101 00102 00103 const Var& var( size_t i ) const { 00104 DAI_DEBASSERT( i < nrVars() ); 00105 return _vars[i]; 00106 } 00107 00109 const std::vector<Var>& vars() const { return _vars; } 00110 00112 const Factor& factor( size_t I ) const { 00113 DAI_DEBASSERT( I < nrFactors() ); 00114 return _factors[I]; 00115 } 00117 const std::vector<Factor>& factors() const { return _factors; } 00118 00120 const Neighbors& nbV( size_t i ) const { return _G.nb1(i); } 00122 const Neighbors& nbF( size_t I ) const { return _G.nb2(I); } 00124 const Neighbor& nbV( size_t i, size_t _I ) const { return _G.nb1(i)[_I]; } 00126 const Neighbor& nbF( size_t I, size_t _i ) const { return _G.nb2(I)[_i]; } 00128 00130 00131 00132 const BipartiteGraph& bipGraph() const { return _G; } 00134 size_t nrVars() const { return vars().size(); } 00136 size_t nrFactors() const { return factors().size(); } 00138 00140 size_t nrEdges() const { return _G.nrEdges(); } 00141 00143 00146 size_t findVar( const Var& n ) const { 00147 size_t i = find( vars().begin(), vars().end(), n ) - vars().begin(); 00148 if( i == nrVars() ) 00149 DAI_THROW(OBJECT_NOT_FOUND); 00150 return i; 00151 } 00152 00154 00157 SmallSet<size_t> findVars( const VarSet& ns ) const { 00158 SmallSet<size_t> result; 00159 for( VarSet::const_iterator n = ns.begin(); n != ns.end(); n++ ) 00160 result.insert( findVar( *n ) ); 00161 return result; 00162 } 00163 00165 00168 size_t findFactor( const VarSet& ns ) const { 00169 size_t I; 00170 for( I = 0; I < nrFactors(); I++ ) 00171 if( factor(I).vars() == ns ) 00172 break; 00173 if( I == nrFactors() ) 00174 DAI_THROW(OBJECT_NOT_FOUND); 00175 return I; 00176 } 00177 00179 VarSet Delta( size_t i ) const; 00180 00182 VarSet Delta( const VarSet& vs ) const; 00183 00185 VarSet delta( size_t i ) const { 00186 return( Delta( i ) / var( i ) ); 00187 } 00188 00190 VarSet delta( const VarSet& vs ) const { 00191 return Delta( vs ) / vs; 00192 } 00193 00195 bool isConnected() const { return _G.isConnected(); } 00196 00198 bool isTree() const { return _G.isTree(); } 00199 00201 bool isPairwise() const; 00202 00204 bool isBinary() const; 00205 00207 00210 GraphAL MarkovGraph() const; 00211 00213 00216 bool isMaximal( size_t I ) const; 00217 00219 00222 size_t maximalFactor( size_t I ) const; 00223 00225 00228 std::vector<VarSet> maximalFactorDomains() const; 00229 00231 Real logScore( const std::vector<size_t>& statevec ) const; 00233 00235 00236 00237 virtual void setFactor( size_t I, const Factor& newFactor, bool backup = false ) { 00238 DAI_ASSERT( newFactor.vars() == factor(I).vars() ); 00239 if( backup ) 00240 backupFactor( I ); 00241 _factors[I] = newFactor; 00242 } 00243 00245 virtual void setFactors( const std::map<size_t, Factor>& facs, bool backup = false ) { 00246 for( std::map<size_t, Factor>::const_iterator fac = facs.begin(); fac != facs.end(); fac++ ) { 00247 if( backup ) 00248 backupFactor( fac->first ); 00249 setFactor( fac->first, fac->second ); 00250 } 00251 } 00252 00254 00256 void backupFactor( size_t I ); 00257 00259 00261 void restoreFactor( size_t I ); 00262 00264 00266 virtual void backupFactors( const std::set<size_t>& facs ); 00267 00269 virtual void restoreFactors(); 00270 00272 00274 void backupFactors( const VarSet& ns ); 00275 00277 void restoreFactors( const VarSet& ns ); 00279 00281 00282 00283 FactorGraph maximalFactors() const; 00284 00286 00289 FactorGraph clamped( size_t i, size_t x ) const; 00291 00293 00294 00295 00297 virtual void clamp( size_t i, size_t x, bool backup = false ); 00298 00300 00302 void clampVar( size_t i, const std::vector<size_t>& xis, bool backup = false ); 00303 00305 00307 void clampFactor( size_t I, const std::vector<size_t>& xIs, bool backup = false ); 00308 00310 00312 virtual void makeCavity( size_t i, bool backup = false ); 00314 00316 00317 00318 00322 virtual void ReadFromFile( const char *filename ); 00323 00325 00328 virtual void WriteToFile( const char *filename, size_t precision=15 ) const; 00329 00331 00333 friend std::ostream& operator<< (std::ostream& os, const FactorGraph& fg ); 00334 00336 00339 friend std::istream& operator>> (std::istream& is, FactorGraph& fg ); 00340 00342 virtual void printDot( std::ostream& os ) const; 00344 00345 private: 00347 void constructGraph( size_t nrEdges ); 00348 }; 00349 00350 00351 template<typename FactorInputIterator, typename VarInputIterator> 00352 FactorGraph::FactorGraph(FactorInputIterator facBegin, FactorInputIterator facEnd, VarInputIterator varBegin, VarInputIterator varEnd, size_t nrFacHint, size_t nrVarHint ) : _G(), _backup() { 00353 // add factors 00354 size_t nrEdges = 0; 00355 _factors.reserve( nrFacHint ); 00356 for( FactorInputIterator p2 = facBegin; p2 != facEnd; ++p2 ) { 00357 _factors.push_back( *p2 ); 00358 nrEdges += p2->vars().size(); 00359 } 00360 00361 // add variables 00362 _vars.reserve( nrVarHint ); 00363 for( VarInputIterator p1 = varBegin; p1 != varEnd; ++p1 ) 00364 _vars.push_back( *p1 ); 00365 00366 // create graph structure 00367 constructGraph( nrEdges ); 00368 } 00369 00370 00377 } // end of namespace dai 00378 00379 00380 #endif