libDAI
|
00001 /* This file is part of libDAI - http://www.libdai.org/ 00002 * 00003 * Copyright (c) 2006-2011, The libDAI authors. All rights reserved. 00004 * 00005 * Use of this source code is governed by a BSD-style license that can be found in the LICENSE file. 00006 */ 00007 00008 00011 00012 00013 #ifndef __defined_libdai_regiongraph_h 00014 #define __defined_libdai_regiongraph_h 00015 00016 00017 #include <iostream> 00018 #include <dai/bipgraph.h> 00019 #include <dai/factorgraph.h> 00020 #include <dai/weightedgraph.h> 00021 00022 00023 namespace dai { 00024 00025 00027 class Region : public VarSet { 00028 private: 00030 Real _c; 00031 00032 public: 00034 Region() : VarSet(), _c(1.0) {} 00035 00037 Region( const VarSet& x, Real c ) : VarSet(x), _c(c) {} 00038 00040 const Real& c() const { return _c; } 00041 00043 Real& c() { return _c; } 00044 }; 00045 00046 00048 class FRegion : public Factor { 00049 private: 00051 Real _c; 00052 00053 public: 00055 FRegion() : Factor(), _c(1.0) {} 00056 00058 FRegion( const Factor& x, Real c ) : Factor(x), _c(c) {} 00059 00061 const Real& c() const { return _c; } 00062 00064 Real& c() { return _c; } 00065 }; 00066 00067 00069 00091 class RegionGraph : public FactorGraph { 00092 protected: 00094 BipartiteGraph _G; 00095 00097 std::vector<FRegion> _ORs; 00098 00100 std::vector<Region> _IRs; 00101 00103 std::vector<size_t> _fac2OR; 00104 00105 00106 public: 00108 00109 00110 RegionGraph() : FactorGraph(), _G(), _ORs(), _IRs(), _fac2OR() {} 00111 00113 00115 RegionGraph( const FactorGraph& fg, const std::vector<VarSet>& ors, const std::vector<Region>& irs, const std::vector<std::pair<size_t,size_t> >& edges ) : FactorGraph(), _G(), _ORs(), _IRs(), _fac2OR() { 00116 construct( fg, ors, irs, edges ); 00117 00118 // Check counting numbers 00119 #ifdef DAI_DEBUG 00120 checkCountingNumbers(); 00121 #endif 00122 } 00123 00125 00136 RegionGraph( const FactorGraph& fg, const std::vector<VarSet>& cl ) : FactorGraph(), _G(), _ORs(), _IRs(), _fac2OR() { 00137 constructCVM( fg, cl ); 00138 00139 // Check counting numbers 00140 #ifdef DAI_DEBUG 00141 checkCountingNumbers(); 00142 #endif 00143 } 00144 00146 virtual RegionGraph* clone() const { return new RegionGraph(*this); } 00148 00150 00151 00152 size_t nrORs() const { return _ORs.size(); } 00154 size_t nrIRs() const { return _IRs.size(); } 00155 00157 const FRegion& OR( size_t alpha ) const { 00158 DAI_DEBASSERT( alpha < nrORs() ); 00159 return _ORs[alpha]; 00160 } 00162 FRegion& OR( size_t alpha ) { 00163 DAI_DEBASSERT( alpha < nrORs() ); 00164 return _ORs[alpha]; 00165 } 00166 00168 const Region& IR( size_t beta ) const { 00169 DAI_DEBASSERT( beta < nrIRs() ); 00170 return _IRs[beta]; 00171 } 00173 Region& IR( size_t beta ) { 00174 DAI_DEBASSERT( beta < nrIRs() ); 00175 return _IRs[beta]; 00176 } 00177 00179 size_t fac2OR( size_t I ) const { 00180 DAI_DEBASSERT( I < nrFactors() ); 00181 DAI_DEBASSERT( I < _fac2OR.size() ); 00182 return _fac2OR[I]; 00183 } 00184 00186 const Neighbors& nbOR( size_t alpha ) const { return _G.nb1(alpha); } 00187 00189 const Neighbors& nbIR( size_t beta ) const { return _G.nb2(beta); } 00190 00192 00196 const BipartiteGraph& DAG() const { return _G; } 00198 00200 00201 00202 00207 bool checkCountingNumbers() const; 00209 00211 00212 00213 virtual void setFactor( size_t I, const Factor& newFactor, bool backup = false ) { 00214 FactorGraph::setFactor( I, newFactor, backup ); 00215 recomputeOR( I ); 00216 } 00217 00219 virtual void setFactors( const std::map<size_t, Factor>& facs, bool backup = false ) { 00220 FactorGraph::setFactors( facs, backup ); 00221 VarSet ns; 00222 for( std::map<size_t, Factor>::const_iterator fac = facs.begin(); fac != facs.end(); fac++ ) 00223 ns |= fac->second.vars(); 00224 recomputeORs( ns ); 00225 } 00227 00229 00230 00231 00233 virtual void ReadFromFile( const char* /*filename*/ ) { 00234 DAI_THROW(NOT_IMPLEMENTED); 00235 } 00236 00238 00240 virtual void WriteToFile( const char* /*filename*/, size_t /*precision*/=15 ) const { 00241 DAI_THROW(NOT_IMPLEMENTED); 00242 } 00243 00245 friend std::ostream& operator<< ( std::ostream& os, const RegionGraph& rg ); 00246 00248 00250 virtual void printDot( std::ostream& /*os*/ ) const { 00251 DAI_THROW(NOT_IMPLEMENTED); 00252 } 00254 00255 protected: 00257 void construct( const FactorGraph& fg, const std::vector<VarSet>& ors, const std::vector<Region>& irs, const std::vector<std::pair<size_t,size_t> >& edges ); 00258 00260 void constructCVM( const FactorGraph& fg, const std::vector<VarSet>& cl, size_t verbose=0 ); 00261 00263 00265 void recomputeORs(); 00266 00268 00270 void recomputeORs( const VarSet& vs ); 00271 00273 00275 void recomputeOR( size_t I ); 00276 00278 00284 void calcCVMCountingNumbers(); 00285 00286 }; 00287 00288 00289 } // end of namespace dai 00290 00291 00292 #endif