libDAI
|
00001 /* This file is part of libDAI - http://www.libdai.org/ 00002 * 00003 * Copyright (c) 2006-2011, The libDAI authors. All rights reserved. 00004 * 00005 * Use of this source code is governed by a BSD-style license that can be found in the LICENSE file. 00006 */ 00007 00008 00013 00014 00015 #ifndef __defined_libdai_clustergraph_h 00016 #define __defined_libdai_clustergraph_h 00017 00018 00019 #include <set> 00020 #include <vector> 00021 #include <dai/varset.h> 00022 #include <dai/bipgraph.h> 00023 #include <dai/factorgraph.h> 00024 00025 00026 namespace dai { 00027 00028 00030 00034 class ClusterGraph { 00035 private: 00037 BipartiteGraph _G; 00038 00040 std::vector<Var> _vars; 00041 00043 std::vector<VarSet> _clusters; 00044 00045 public: 00047 00048 00049 ClusterGraph() : _G(), _vars(), _clusters() {} 00050 00052 ClusterGraph( const std::vector<VarSet>& cls ); 00053 00055 00058 ClusterGraph( const FactorGraph& fg, bool onlyMaximal ); 00060 00062 00063 00064 const BipartiteGraph& bipGraph() const { return _G; } 00065 00067 size_t nrVars() const { return _vars.size(); } 00068 00070 const std::vector<Var>& vars() const { return _vars; } 00071 00073 const Var& var( size_t i ) const { 00074 DAI_DEBASSERT( i < nrVars() ); 00075 return _vars[i]; 00076 } 00077 00079 size_t nrClusters() const { return _clusters.size(); } 00080 00082 const std::vector<VarSet>& clusters() const { return _clusters; } 00083 00085 const VarSet& cluster( size_t I ) const { 00086 DAI_DEBASSERT( I < nrClusters() ); 00087 return _clusters[I]; 00088 } 00089 00091 size_t findVar( const Var& n ) const { 00092 return find( _vars.begin(), _vars.end(), n ) - _vars.begin(); 00093 } 00094 00096 size_t findCluster( const VarSet& cl ) const { 00097 return find( _clusters.begin(), _clusters.end(), cl ) - _clusters.begin(); 00098 } 00099 00101 VarSet Delta( size_t i ) const { 00102 VarSet result; 00103 foreach( const Neighbor& I, _G.nb1(i) ) 00104 result |= _clusters[I]; 00105 return result; 00106 } 00107 00109 VarSet delta( size_t i ) const { 00110 return Delta( i ) / _vars[i]; 00111 } 00112 00114 bool adj( size_t i1, size_t i2 ) const { 00115 if( i1 == i2 ) 00116 return false; 00117 bool result = false; 00118 foreach( const Neighbor& I, _G.nb1(i1) ) 00119 if( find( _G.nb2(I).begin(), _G.nb2(I).end(), i2 ) != _G.nb2(I).end() ) { 00120 result = true; 00121 break; 00122 } 00123 return result; 00124 } 00125 00127 bool isMaximal( size_t I ) const { 00128 DAI_DEBASSERT( I < _G.nrNodes2() ); 00129 const VarSet & clI = _clusters[I]; 00130 bool maximal = true; 00131 // The following may not be optimal, since it may repeatedly test the same cluster *J 00132 foreach( const Neighbor& i, _G.nb2(I) ) { 00133 foreach( const Neighbor& J, _G.nb1(i) ) 00134 if( (J != I) && (clI << _clusters[J]) ) { 00135 maximal = false; 00136 break; 00137 } 00138 if( !maximal ) 00139 break; 00140 } 00141 return maximal; 00142 } 00144 00146 00147 00148 00151 size_t insert( const VarSet& cl ) { 00152 size_t index = findCluster( cl ); // OPTIMIZE ME 00153 if( index == _clusters.size() ) { 00154 _clusters.push_back( cl ); 00155 // add variables (if necessary) and calculate neighborhood of new cluster 00156 std::vector<size_t> nbs; 00157 for( VarSet::const_iterator n = cl.begin(); n != cl.end(); n++ ) { 00158 size_t iter = findVar( *n ); // OPTIMIZE ME 00159 nbs.push_back( iter ); 00160 if( iter == _vars.size() ) { 00161 _G.addNode1(); 00162 _vars.push_back( *n ); 00163 } 00164 } 00165 _G.addNode2( nbs.begin(), nbs.end(), nbs.size() ); 00166 } 00167 return index; 00168 } 00169 00171 ClusterGraph& eraseNonMaximal() { 00172 for( size_t I = 0; I < _G.nrNodes2(); ) { 00173 if( !isMaximal(I) ) { 00174 _clusters.erase( _clusters.begin() + I ); 00175 _G.eraseNode2(I); 00176 } else 00177 I++; 00178 } 00179 return *this; 00180 } 00181 00183 ClusterGraph& eraseSubsuming( size_t i ) { 00184 DAI_ASSERT( i < nrVars() ); 00185 while( _G.nb1(i).size() ) { 00186 _clusters.erase( _clusters.begin() + _G.nb1(i)[0] ); 00187 _G.eraseNode2( _G.nb1(i)[0] ); 00188 } 00189 return *this; 00190 } 00191 00193 00195 VarSet elimVar( size_t i ) { 00196 DAI_ASSERT( i < nrVars() ); 00197 VarSet Di = Delta( i ); 00198 insert( Di / var(i) ); 00199 eraseSubsuming( i ); 00200 eraseNonMaximal(); 00201 return Di; 00202 } 00204 00206 00207 00208 friend std::ostream& operator << ( std::ostream& os, const ClusterGraph& cl ) { 00209 os << cl.clusters(); 00210 return os; 00211 } 00213 00215 00216 00217 00223 template<class EliminationChoice> 00224 ClusterGraph VarElim( EliminationChoice f, size_t maxStates=0 ) const { 00225 // Make a copy 00226 ClusterGraph cl(*this); 00227 cl.eraseNonMaximal(); 00228 00229 ClusterGraph result; 00230 00231 // Construct set of variable indices 00232 std::set<size_t> varindices; 00233 for( size_t i = 0; i < _vars.size(); ++i ) 00234 varindices.insert( i ); 00235 00236 // Do variable elimination 00237 BigInt totalStates = 0; 00238 while( !varindices.empty() ) { 00239 size_t i = f( cl, varindices ); 00240 VarSet Di = cl.elimVar( i ); 00241 result.insert( Di ); 00242 if( maxStates ) { 00243 totalStates += Di.nrStates(); 00244 if( totalStates > maxStates ) 00245 DAI_THROW(OUT_OF_MEMORY); 00246 } 00247 varindices.erase( i ); 00248 } 00249 00250 return result; 00251 } 00253 }; 00254 00255 00257 00259 class sequentialVariableElimination { 00260 private: 00262 std::vector<Var> seq; 00264 size_t i; 00265 00266 public: 00268 sequentialVariableElimination( const std::vector<Var> s ) : seq(s), i(0) {} 00269 00271 size_t operator()( const ClusterGraph &cl, const std::set<size_t> &/*remainingVars*/ ); 00272 }; 00273 00274 00276 00279 class greedyVariableElimination { 00280 public: 00282 typedef size_t (*eliminationCostFunction)(const ClusterGraph &, size_t); 00283 00284 private: 00286 eliminationCostFunction heuristic; 00287 00288 public: 00290 00292 greedyVariableElimination( eliminationCostFunction h ) : heuristic(h) {} 00293 00295 00297 size_t operator()( const ClusterGraph &cl, const std::set<size_t>& remainingVars ); 00298 }; 00299 00300 00302 00306 size_t eliminationCost_MinNeighbors( const ClusterGraph& cl, size_t i ); 00307 00308 00310 00315 size_t eliminationCost_MinWeight( const ClusterGraph& cl, size_t i ); 00316 00317 00319 00323 size_t eliminationCost_MinFill( const ClusterGraph& cl, size_t i ); 00324 00325 00327 00332 size_t eliminationCost_WeightedMinFill( const ClusterGraph& cl, size_t i ); 00333 00334 00335 } // end of namespace dai 00336 00337 00338 #endif