libDAI
include/dai/clustergraph.h
Go to the documentation of this file.
00001 /*  This file is part of libDAI - http://www.libdai.org/
00002  *
00003  *  Copyright (c) 2006-2011, The libDAI authors. All rights reserved.
00004  *
00005  *  Use of this source code is governed by a BSD-style license that can be found in the LICENSE file.
00006  */
00007 
00008 
00013 
00014 
00015 #ifndef __defined_libdai_clustergraph_h
00016 #define __defined_libdai_clustergraph_h
00017 
00018 
#include <algorithm>
#include <set>
#include <vector>
#include <dai/varset.h>
#include <dai/bipgraph.h>
#include <dai/factorgraph.h>
00024 
00025 
00026 namespace dai {
00027 
00028 
00030 
00034     class ClusterGraph {
00035         private:
00037             BipartiteGraph       _G;
00038 
00040             std::vector<Var>     _vars;
00041 
00043             std::vector<VarSet>  _clusters;
00044 
00045         public:
00047 
00048 
00049             ClusterGraph() : _G(), _vars(), _clusters() {}
00050 
00052             ClusterGraph( const std::vector<VarSet>& cls );
00053 
00055 
00058             ClusterGraph( const FactorGraph& fg, bool onlyMaximal );
00060 
00062 
00063 
00064             const BipartiteGraph& bipGraph() const { return _G; }
00065 
00067             size_t nrVars() const { return _vars.size(); }
00068 
00070             const std::vector<Var>& vars() const { return _vars; }
00071 
00073             const Var& var( size_t i ) const {
00074                 DAI_DEBASSERT( i < nrVars() );
00075                 return _vars[i]; 
00076             }
00077 
00079             size_t nrClusters() const { return _clusters.size(); }
00080 
00082             const std::vector<VarSet>& clusters() const { return _clusters; }
00083 
00085             const VarSet& cluster( size_t I ) const {
00086                 DAI_DEBASSERT( I < nrClusters() );
00087                 return _clusters[I]; 
00088             }
00089 
00091             size_t findVar( const Var& n ) const {
00092                 return find( _vars.begin(), _vars.end(), n ) - _vars.begin();
00093             }
00094 
00096             size_t findCluster( const VarSet& cl ) const {
00097                 return find( _clusters.begin(), _clusters.end(), cl ) - _clusters.begin();
00098             }
00099 
00101             VarSet Delta( size_t i ) const {
00102                 VarSet result;
00103                 foreach( const Neighbor& I, _G.nb1(i) )
00104                     result |= _clusters[I];
00105                 return result;
00106             }
00107 
00109             VarSet delta( size_t i ) const {
00110                 return Delta( i ) / _vars[i];
00111             }
00112 
00114             bool adj( size_t i1, size_t i2 ) const {
00115                 if( i1 == i2 )
00116                     return false;
00117                 bool result = false;
00118                 foreach( const Neighbor& I, _G.nb1(i1) )
00119                     if( find( _G.nb2(I).begin(), _G.nb2(I).end(), i2 ) != _G.nb2(I).end() ) {
00120                         result = true;
00121                         break;
00122                     }
00123                 return result;
00124             }
00125 
00127             bool isMaximal( size_t I ) const {
00128                 DAI_DEBASSERT( I < _G.nrNodes2() );
00129                 const VarSet & clI = _clusters[I];
00130                 bool maximal = true;
00131                 // The following may not be optimal, since it may repeatedly test the same cluster *J
00132                 foreach( const Neighbor& i, _G.nb2(I) ) {
00133                     foreach( const Neighbor& J, _G.nb1(i) )
00134                         if( (J != I) && (clI << _clusters[J]) ) {
00135                             maximal = false;
00136                             break;
00137                         }
00138                     if( !maximal )
00139                         break;
00140                 }
00141                 return maximal;
00142             }
00144 
00146 
00147 
00148 
00151             size_t insert( const VarSet& cl ) {
00152                 size_t index = findCluster( cl );  // OPTIMIZE ME
00153                 if( index == _clusters.size() ) {
00154                     _clusters.push_back( cl );
00155                     // add variables (if necessary) and calculate neighborhood of new cluster
00156                     std::vector<size_t> nbs;
00157                     for( VarSet::const_iterator n = cl.begin(); n != cl.end(); n++ ) {
00158                         size_t iter = findVar( *n );  // OPTIMIZE ME
00159                         nbs.push_back( iter );
00160                         if( iter == _vars.size() ) {
00161                             _G.addNode1();
00162                             _vars.push_back( *n );
00163                         }
00164                     }
00165                     _G.addNode2( nbs.begin(), nbs.end(), nbs.size() );
00166                 }
00167                 return index;
00168             }
00169 
00171             ClusterGraph& eraseNonMaximal() {
00172                 for( size_t I = 0; I < _G.nrNodes2(); ) {
00173                     if( !isMaximal(I) ) {
00174                         _clusters.erase( _clusters.begin() + I );
00175                         _G.eraseNode2(I);
00176                     } else
00177                         I++;
00178                 }
00179                 return *this;
00180             }
00181 
00183             ClusterGraph& eraseSubsuming( size_t i ) {
00184                 DAI_ASSERT( i < nrVars() );
00185                 while( _G.nb1(i).size() ) {
00186                     _clusters.erase( _clusters.begin() + _G.nb1(i)[0] );
00187                     _G.eraseNode2( _G.nb1(i)[0] );
00188                 }
00189                 return *this;
00190             }
00191 
00193 
00195             VarSet elimVar( size_t i ) {
00196                 DAI_ASSERT( i < nrVars() );
00197                 VarSet Di = Delta( i );
00198                 insert( Di / var(i) );
00199                 eraseSubsuming( i );
00200                 eraseNonMaximal();
00201                 return Di;
00202             }
00204 
00206 
00207 
00208             friend std::ostream& operator << ( std::ostream& os, const ClusterGraph& cl ) {
00209                 os << cl.clusters();
00210                 return os;
00211             }
00213 
00215 
00216 
00217 
00223             template<class EliminationChoice>
00224             ClusterGraph VarElim( EliminationChoice f, size_t maxStates=0 ) const {
00225                 // Make a copy
00226                 ClusterGraph cl(*this);
00227                 cl.eraseNonMaximal();
00228 
00229                 ClusterGraph result;
00230 
00231                 // Construct set of variable indices
00232                 std::set<size_t> varindices;
00233                 for( size_t i = 0; i < _vars.size(); ++i )
00234                     varindices.insert( i );
00235 
00236                 // Do variable elimination
00237                 BigInt totalStates = 0;
00238                 while( !varindices.empty() ) {
00239                     size_t i = f( cl, varindices );
00240                     VarSet Di = cl.elimVar( i );
00241                     result.insert( Di );
00242                     if( maxStates ) {
00243                         totalStates += Di.nrStates();
00244                         if( totalStates > maxStates )
00245                             DAI_THROW(OUT_OF_MEMORY);
00246                     }
00247                     varindices.erase( i );
00248                 }
00249 
00250                 return result;
00251             }
00253     };
00254 
00255 
00257 
00259     class sequentialVariableElimination {
00260         private:
00262             std::vector<Var> seq;
00264             size_t i;
00265 
00266         public:
00268             sequentialVariableElimination( const std::vector<Var> s ) : seq(s), i(0) {}
00269 
00271            size_t operator()( const ClusterGraph &cl, const std::set<size_t> &/*remainingVars*/ );
00272     };
00273 
00274 
00276 
00279     class greedyVariableElimination {
00280         public:
00282             typedef size_t (*eliminationCostFunction)(const ClusterGraph &, size_t);
00283 
00284         private:
00286             eliminationCostFunction heuristic;
00287 
00288         public:
00290 
00292             greedyVariableElimination( eliminationCostFunction h ) : heuristic(h) {}
00293 
00295 
00297             size_t operator()( const ClusterGraph &cl, const std::set<size_t>& remainingVars );
00298     };
00299 
00300 
00302 
00306     size_t eliminationCost_MinNeighbors( const ClusterGraph& cl, size_t i );
00307 
00308 
00310 
00315     size_t eliminationCost_MinWeight( const ClusterGraph& cl, size_t i );
00316 
00317 
00319 
00323     size_t eliminationCost_MinFill( const ClusterGraph& cl, size_t i );
00324 
00325 
00327 
00332     size_t eliminationCost_WeightedMinFill( const ClusterGraph& cl, size_t i );
00333 
00334 
00335 } // end of namespace dai
00336 
00337 
00338 #endif