libDAI
|
00001 /* This file is part of libDAI - http://www.libdai.org/ 00002 * 00003 * Copyright (c) 2006-2011, The libDAI authors. All rights reserved. 00004 * 00005 * Use of this source code is governed by a BSD-style license that can be found in the LICENSE file. 00006 */ 00007 00008 00009 #ifndef __defined_libdai_emalg_h 00010 #define __defined_libdai_emalg_h 00011 00012 00013 #include <vector> 00014 #include <map> 00015 00016 #include <dai/factor.h> 00017 #include <dai/daialg.h> 00018 #include <dai/evidence.h> 00019 #include <dai/index.h> 00020 #include <dai/properties.h> 00021 00022 00026 00027 00028 namespace dai { 00029 00030 00032 00050 class ParameterEstimation { 00051 public: 00053 typedef ParameterEstimation* (*ParamEstFactory)( const PropertySet& ); 00054 00056 virtual ~ParameterEstimation() {} 00057 00059 virtual ParameterEstimation* clone() const = 0; 00060 00062 00067 static ParameterEstimation* construct( const std::string &method, const PropertySet &p ); 00068 00070 static void registerMethod( const std::string &method, const ParamEstFactory &f ) { 00071 if( _registry == NULL ) 00072 loadDefaultRegistry(); 00073 (*_registry)[method] = f; 00074 } 00075 00077 virtual Prob estimate() = 0; 00078 00080 virtual void addSufficientStatistics( const Prob &p ) = 0; 00081 00083 virtual size_t probSize() const = 0; 00084 00085 private: 00087 static std::map<std::string, ParamEstFactory> *_registry; 00088 00090 static void loadDefaultRegistry(); 00091 }; 00092 00093 00095 00097 class CondProbEstimation : private ParameterEstimation { 00098 private: 00100 size_t _target_dim; 00102 Prob _stats; 00104 Prob _initial_stats; 00105 00106 public: 00108 00112 CondProbEstimation( size_t target_dimension, const Prob &pseudocounts ); 00113 00115 00123 static ParameterEstimation* factory( const PropertySet &p ); 00124 00126 virtual ParameterEstimation* clone() const { return new CondProbEstimation( _target_dim, _initial_stats ); } 00127 00129 virtual ~CondProbEstimation() {} 00130 00132 00135 virtual Prob estimate(); 00136 00138 virtual void addSufficientStatistics( const Prob &p ); 00139 00141 virtual size_t probSize() const { return _stats.size(); } 00142 }; 00143 00144 00146 00156 class SharedParameters { 00157 public: 00159 typedef size_t FactorIndex; 00161 typedef std::map<FactorIndex, std::vector<Var> > FactorOrientations; 00162 00163 private: 00165 std::map<FactorIndex, VarSet> _varsets; 00167 std::map<FactorIndex, Permute> _perms; 00169 FactorOrientations _varorders; 00171 ParameterEstimation *_estimation; 00173 bool _ownEstimation; 00174 00176 00180 static Permute calculatePermutation( const std::vector<Var> &varOrder, VarSet &outVS ); 00181 00183 void setPermsAndVarSetsFromVarOrders(); 00184 00185 public: 00187 00191 SharedParameters( const FactorOrientations &varorders, ParameterEstimation *estimation, bool ownPE=false ); 00192 00194 00197 SharedParameters( std::istream &is, const FactorGraph &fg ); 00198 00200 SharedParameters( const SharedParameters &sp ) : _varsets(sp._varsets), _perms(sp._perms), _varorders(sp._varorders), _estimation(sp._estimation), _ownEstimation(sp._ownEstimation) { 00201 // If sp owns its _estimation object, we should clone it instead of copying the pointer 00202 if( _ownEstimation ) 00203 _estimation = _estimation->clone(); 00204 } 00205 00207 ~SharedParameters() { 00208 // If we own the _estimation object, we should delete it now 00209 if( _ownEstimation ) 00210 delete _estimation; 00211 } 00212 00214 00220 void collectSufficientStatistics( InfAlg &alg ); 00221 00223 00228 void setParameters( FactorGraph &fg ); 00229 }; 00230 00231 00233 00235 class MaximizationStep { 00236 private: 00238 std::vector<SharedParameters> _params; 00239 00240 public: 00242 MaximizationStep() : _params() {} 00243 00245 MaximizationStep( std::vector<SharedParameters> &maximizations ) : _params(maximizations) {} 00246 00248 00250 MaximizationStep( std::istream &is, const FactorGraph &fg_varlookup ); 00251 00253 void addExpectations( InfAlg &alg ); 00254 00256 void maximize( FactorGraph &fg ); 00257 00259 00260 00261 typedef std::vector<SharedParameters>::iterator iterator; 00263 typedef std::vector<SharedParameters>::const_iterator const_iterator; 00264 00266 iterator begin() { return _params.begin(); } 00268 const_iterator begin() const { return _params.begin(); } 00270 iterator end() { return _params.end(); } 00272 const_iterator end() const { return _params.end(); } 00274 }; 00275 00276 00278 00295 class EMAlg { 00296 private: 00298 const Evidence &_evidence; 00299 00301 InfAlg &_estep; 00302 00304 std::vector<MaximizationStep> _msteps; 00305 00307 size_t _iters; 00308 00310 std::vector<Real> _lastLogZ; 00311 00313 size_t _max_iters; 00314 00316 Real _log_z_tol; 00317 00318 public: 00320 static const std::string MAX_ITERS_KEY; 00322 static const size_t MAX_ITERS_DEFAULT; 00324 static const std::string LOG_Z_TOL_KEY; 00326 static const Real LOG_Z_TOL_DEFAULT; 00327 00329 00334 EMAlg( const Evidence &evidence, InfAlg &estep, std::vector<MaximizationStep> &msteps, const PropertySet &termconditions ) 00335 : _evidence(evidence), _estep(estep), _msteps(msteps), _iters(0), _lastLogZ(), _max_iters(MAX_ITERS_DEFAULT), _log_z_tol(LOG_Z_TOL_DEFAULT) 00336 { 00337 setTermConditions( termconditions ); 00338 } 00339 00341 00343 EMAlg( const Evidence &evidence, InfAlg &estep, std::istream &mstep_file ); 00344 00346 00352 void setTermConditions( const PropertySet &p ); 00353 00355 00361 bool hasSatisfiedTermConditions() const; 00362 00364 Real logZ() const { return _lastLogZ.back(); } 00365 00367 size_t Iterations() const { return _iters; } 00368 00370 const InfAlg& eStep() const { return _estep; } 00371 00373 00375 Real iterate(); 00376 00378 Real iterate( MaximizationStep &mstep ); 00379 00381 void run(); 00382 00384 00385 00386 typedef std::vector<MaximizationStep>::iterator s_iterator; 00388 typedef std::vector<MaximizationStep>::const_iterator const_s_iterator; 00389 00391 s_iterator s_begin() { return _msteps.begin(); } 00393 const_s_iterator s_begin() const { return _msteps.begin(); } 00395 s_iterator s_end() { return _msteps.end(); } 00397 const_s_iterator s_end() const { return _msteps.end(); } 00399 }; 00400 00401 00402 } // end of namespace dai 00403 00404 00410 #endif