libDAI
|
00001 /* This file is part of libDAI - http://www.libdai.org/ 00002 * 00003 * Copyright (c) 2006-2011, The libDAI authors. All rights reserved. 00004 * 00005 * Use of this source code is governed by a BSD-style license that can be found in the LICENSE file. 00006 */ 00007 00008 00013 00014 00015 #ifndef __defined_libdai_treeep_h 00016 #define __defined_libdai_treeep_h 00017 00018 00019 #include <vector> 00020 #include <string> 00021 #include <dai/daialg.h> 00022 #include <dai/varset.h> 00023 #include <dai/regiongraph.h> 00024 #include <dai/factorgraph.h> 00025 #include <dai/clustergraph.h> 00026 #include <dai/weightedgraph.h> 00027 #include <dai/jtree.h> 00028 #include <dai/properties.h> 00029 #include <dai/enum.h> 00030 00031 00032 namespace dai { 00033 00034 00036 class TreeEP : public JTree { 00037 private: 00039 Real _maxdiff; 00041 size_t _iters; 00042 00043 public: 00045 struct Properties { 00047 00053 DAI_ENUM(TypeType,ORG,ALT); 00054 00056 size_t verbose; 00057 00059 size_t maxiter; 00060 00062 double maxtime; 00063 00065 Real tol; 00066 00068 TypeType type; 00069 } props; 00070 00071 private: 00073 00077 class TreeEPSubTree { 00078 private: 00080 std::vector<Factor> _Qa; 00082 std::vector<Factor> _Qb; 00084 RootedTree _RTree; 00086 std::vector<size_t> _a; 00088 std::vector<size_t> _b; 00090 const Factor * _I; 00092 VarSet _ns; 00094 VarSet _nsrem; 00096 Real _logZ; 00097 00098 public: 00100 00101 00102 TreeEPSubTree() : _Qa(), _Qb(), _RTree(), _a(), _b(), _I(NULL), _ns(), _nsrem(), _logZ(0.0) {} 00103 00105 TreeEPSubTree( const TreeEPSubTree &x ) : _Qa(x._Qa), _Qb(x._Qb), _RTree(x._RTree), _a(x._a), _b(x._b), _I(x._I), _ns(x._ns), _nsrem(x._nsrem), _logZ(x._logZ) {} 00106 00108 TreeEPSubTree & operator=( const TreeEPSubTree& x ) { 00109 if( this != &x ) { 00110 _Qa = x._Qa; 00111 _Qb = x._Qb; 00112 _RTree = x._RTree; 00113 _a = x._a; 00114 _b = x._b; 00115 _I = x._I; 00116 _ns = x._ns; 00117 _nsrem = x._nsrem; 00118 _logZ = x._logZ; 00119 } 00120 return *this; 00121 } 00122 00124 TreeEPSubTree( const RootedTree &subRTree, const RootedTree &jt_RTree, const std::vector<Factor> &jt_Qa, const std::vector<Factor> &jt_Qb, const Factor *I ); 00126 00128 void init(); 00129 00131 void InvertAndMultiply( const std::vector<Factor> &Qa, const std::vector<Factor> &Qb ); 00132 00134 void HUGIN_with_I( std::vector<Factor> &Qa, std::vector<Factor> &Qb ); 00135 00137 Real logZ( const std::vector<Factor> &Qa, const std::vector<Factor> &Qb ) const; 00138 00140 const Factor *& I() { return _I; } 00141 }; 00142 00144 std::map<size_t, TreeEPSubTree> _Q; 00145 00146 public: 00148 TreeEP() : JTree(), _maxdiff(0.0), _iters(0), props(), _Q() {} 00149 00151 TreeEP( const TreeEP &x ) : JTree(x), _maxdiff(x._maxdiff), _iters(x._iters), props(x.props), _Q(x._Q) { 00152 for( size_t I = 0; I < nrFactors(); I++ ) 00153 if( offtree( I ) ) 00154 _Q[I].I() = &factor(I); 00155 } 00156 00158 TreeEP& operator=( const TreeEP &x ) { 00159 if( this != &x ) { 00160 JTree::operator=( x ); 00161 _maxdiff = x._maxdiff; 00162 _iters = x._iters; 00163 props = x.props; 00164 _Q = x._Q; 00165 for( size_t I = 0; I < nrFactors(); I++ ) 00166 if( offtree( I ) ) 00167 _Q[I].I() = &factor(I); 00168 } 00169 return *this; 00170 } 00171 00173 00176 TreeEP( const FactorGraph &fg, const PropertySet &opts ); 00177 00178 00180 00181 virtual TreeEP* clone() const { return new TreeEP(*this); } 00182 virtual TreeEP* construct( const FactorGraph &fg, const PropertySet &opts ) const { return new TreeEP( fg, opts ); } 00183 virtual std::string name() const { return "TREEEP"; } 00184 virtual Real logZ() const; 00185 virtual void init(); 00186 virtual void init( const VarSet &/*ns*/ ) { init(); } 00187 virtual Real run(); 00188 virtual Real maxDiff() const { return _maxdiff; } 00189 virtual size_t Iterations() const { return _iters; } 00190 virtual void setMaxIter( size_t maxiter ) { props.maxiter = maxiter; } 00191 virtual void setProperties( const PropertySet &opts ); 00192 virtual PropertySet getProperties() const; 00193 virtual std::string printProperties() const; 00195 00196 private: 00198 void construct( const FactorGraph& fg, const RootedTree& tree ); 00200 bool offtree( size_t I ) const { return (fac2OR(I) == -1U); } 00201 }; 00202 00203 00204 } // end of namespace dai 00205 00206 00207 #endif