libDAI
include/dai/factorgraph.h
Go to the documentation of this file.
00001 /*  This file is part of libDAI - http://www.libdai.org/
00002  *
00003  *  Copyright (c) 2006-2011, The libDAI authors. All rights reserved.
00004  *
00005  *  Use of this source code is governed by a BSD-style license that can be found in the LICENSE file.
00006  */
00007 
00008 
00011 
00012 
00013 #ifndef __defined_libdai_factorgraph_h
00014 #define __defined_libdai_factorgraph_h
00015 
00016 
#include <algorithm>
#include <iostream>
#include <map>
#include <set>
#include <vector>
#include <dai/bipgraph.h>
#include <dai/graph.h>
#include <dai/factor.h>
00022 
00023 
00024 namespace dai {
00025 
00026 
00028 
00065 class FactorGraph {
00066     private:
00068         BipartiteGraph           _G;
00070         std::vector<Var>         _vars;
00072         std::vector<Factor>      _factors;
00074         std::map<size_t,Factor>  _backup;
00075 
00076     public:
00078 
00079 
00080         FactorGraph() : _G(), _vars(), _factors(), _backup() {}
00081 
00083         FactorGraph( const std::vector<Factor>& P );
00084 
00086 
00090         template<typename FactorInputIterator, typename VarInputIterator>
00091         FactorGraph(FactorInputIterator facBegin, FactorInputIterator facEnd, VarInputIterator varBegin, VarInputIterator varEnd, size_t nrFacHint = 0, size_t nrVarHint = 0 );
00092 
00094         virtual ~FactorGraph() {}
00095 
00097         virtual FactorGraph* clone() const { return new FactorGraph(*this); }
00099 
00101 
00102 
00103         const Var& var( size_t i ) const { 
00104             DAI_DEBASSERT( i < nrVars() );
00105             return _vars[i]; 
00106         }
00107 
00109         const std::vector<Var>& vars() const { return _vars; }
00110 
00112         const Factor& factor( size_t I ) const { 
00113             DAI_DEBASSERT( I < nrFactors() );
00114             return _factors[I]; 
00115         }
00117         const std::vector<Factor>& factors() const { return _factors; }
00118 
00120         const Neighbors& nbV( size_t i ) const { return _G.nb1(i); }
00122         const Neighbors& nbF( size_t I ) const { return _G.nb2(I); }
00124         const Neighbor& nbV( size_t i, size_t _I ) const { return _G.nb1(i)[_I]; }
00126         const Neighbor& nbF( size_t I, size_t _i ) const { return _G.nb2(I)[_i]; }
00128 
00130 
00131 
00132         const BipartiteGraph& bipGraph() const { return _G; }
00134         size_t nrVars() const { return vars().size(); }
00136         size_t nrFactors() const { return factors().size(); }
00138 
00140         size_t nrEdges() const { return _G.nrEdges(); }
00141 
00143 
00146         size_t findVar( const Var& n ) const {
00147             size_t i = find( vars().begin(), vars().end(), n ) - vars().begin();
00148             if( i == nrVars() )
00149                 DAI_THROW(OBJECT_NOT_FOUND);
00150             return i;
00151         }
00152 
00154 
00157         SmallSet<size_t> findVars( const VarSet& ns ) const {
00158             SmallSet<size_t> result;
00159             for( VarSet::const_iterator n = ns.begin(); n != ns.end(); n++ )
00160                 result.insert( findVar( *n ) );
00161             return result;
00162         }
00163 
00165 
00168         size_t findFactor( const VarSet& ns ) const {
00169             size_t I;
00170             for( I = 0; I < nrFactors(); I++ )
00171                 if( factor(I).vars() == ns )
00172                     break;
00173             if( I == nrFactors() )
00174                 DAI_THROW(OBJECT_NOT_FOUND);
00175             return I;
00176         }
00177 
00179         VarSet Delta( size_t i ) const;
00180 
00182         VarSet Delta( const VarSet& vs ) const;
00183 
00185         VarSet delta( size_t i ) const {
00186             return( Delta( i ) / var( i ) );
00187         }
00188 
00190         VarSet delta( const VarSet& vs ) const {
00191             return Delta( vs ) / vs;
00192         }
00193 
00195         bool isConnected() const { return _G.isConnected(); }
00196 
00198         bool isTree() const { return _G.isTree(); }
00199 
00201         bool isPairwise() const;
00202 
00204         bool isBinary() const;
00205 
00207 
00210         GraphAL MarkovGraph() const;
00211 
00213 
00216         bool isMaximal( size_t I ) const;
00217 
00219 
00222         size_t maximalFactor( size_t I ) const;
00223 
00225 
00228         std::vector<VarSet> maximalFactorDomains() const;
00229 
00231         Real logScore( const std::vector<size_t>& statevec ) const;
00233 
00235 
00236 
00237         virtual void setFactor( size_t I, const Factor& newFactor, bool backup = false ) {
00238             DAI_ASSERT( newFactor.vars() == factor(I).vars() );
00239             if( backup )
00240                 backupFactor( I );
00241             _factors[I] = newFactor;
00242         }
00243 
00245         virtual void setFactors( const std::map<size_t, Factor>& facs, bool backup = false ) {
00246             for( std::map<size_t, Factor>::const_iterator fac = facs.begin(); fac != facs.end(); fac++ ) {
00247                 if( backup )
00248                     backupFactor( fac->first );
00249                 setFactor( fac->first, fac->second );
00250             }
00251         }
00252 
00254 
00256         void backupFactor( size_t I );
00257 
00259 
00261         void restoreFactor( size_t I );
00262 
00264 
00266         virtual void backupFactors( const std::set<size_t>& facs );
00267 
00269         virtual void restoreFactors();
00270 
00272 
00274         void backupFactors( const VarSet& ns );
00275 
00277         void restoreFactors( const VarSet& ns );
00279 
00281 
00282 
00283         FactorGraph maximalFactors() const;
00284 
00286 
00289         FactorGraph clamped( size_t i, size_t x ) const;
00291 
00293 
00294 
00295 
00297         virtual void clamp( size_t i, size_t x, bool backup = false );
00298 
00300 
00302         void clampVar( size_t i, const std::vector<size_t>& xis, bool backup = false );
00303 
00305 
00307         void clampFactor( size_t I, const std::vector<size_t>& xIs, bool backup = false );
00308 
00310 
00312         virtual void makeCavity( size_t i, bool backup = false );
00314 
00316 
00317 
00318 
00322         virtual void ReadFromFile( const char *filename );
00323 
00325 
00328         virtual void WriteToFile( const char *filename, size_t precision=15 ) const;
00329 
00331 
00333         friend std::ostream& operator<< (std::ostream& os, const FactorGraph& fg );
00334 
00336 
00339         friend std::istream& operator>> (std::istream& is, FactorGraph& fg );
00340 
00342         virtual void printDot( std::ostream& os ) const;
00344 
00345     private:
00347         void constructGraph( size_t nrEdges );
00348 };
00349 
00350 
00351 template<typename FactorInputIterator, typename VarInputIterator>
00352 FactorGraph::FactorGraph(FactorInputIterator facBegin, FactorInputIterator facEnd, VarInputIterator varBegin, VarInputIterator varEnd, size_t nrFacHint, size_t nrVarHint ) : _G(), _backup() {
00353     // add factors
00354     size_t nrEdges = 0;
00355     _factors.reserve( nrFacHint );
00356     for( FactorInputIterator p2 = facBegin; p2 != facEnd; ++p2 ) {
00357         _factors.push_back( *p2 );
00358         nrEdges += p2->vars().size();
00359     }
00360 
00361     // add variables
00362     _vars.reserve( nrVarHint );
00363     for( VarInputIterator p1 = varBegin; p1 != varEnd; ++p1 )
00364         _vars.push_back( *p1 );
00365 
00366     // create graph structure
00367     constructGraph( nrEdges );
00368 }
00369 
00370 
00377 } // end of namespace dai
00378 
00379 
00380 #endif