libDAI
include/dai/emalg.h
Go to the documentation of this file.
00001 /*  This file is part of libDAI - http://www.libdai.org/
00002  *
00003  *  Copyright (c) 2006-2011, The libDAI authors. All rights reserved.
00004  *
00005  *  Use of this source code is governed by a BSD-style license that can be found in the LICENSE file.
00006  */
00007 
00008 
00009 #ifndef __defined_libdai_emalg_h
00010 #define __defined_libdai_emalg_h
00011 
00012 
00013 #include <vector>
00014 #include <map>
00015 
00016 #include <dai/factor.h>
00017 #include <dai/daialg.h>
00018 #include <dai/evidence.h>
00019 #include <dai/index.h>
00020 #include <dai/properties.h>
00021 
00022 
00026 
00027 
00028 namespace dai {
00029 
00030 
00032 
00050 class ParameterEstimation {
00051     public:
00053         typedef ParameterEstimation* (*ParamEstFactory)( const PropertySet& );
00054 
00056         virtual ~ParameterEstimation() {}
00057 
00059         virtual ParameterEstimation* clone() const = 0;
00060 
00062 
00067         static ParameterEstimation* construct( const std::string &method, const PropertySet &p );
00068 
00070         static void registerMethod( const std::string &method, const ParamEstFactory &f ) {
00071             if( _registry == NULL )
00072                 loadDefaultRegistry();
00073             (*_registry)[method] = f;
00074         }
00075 
00077         virtual Prob estimate() = 0;
00078 
00080         virtual void addSufficientStatistics( const Prob &p ) = 0;
00081 
00083         virtual size_t probSize() const = 0;
00084 
00085     private:
00087         static std::map<std::string, ParamEstFactory> *_registry;
00088 
00090         static void loadDefaultRegistry();
00091 };
00092 
00093 
00095 
00097 class CondProbEstimation : private ParameterEstimation {
00098     private:
00100         size_t _target_dim;
00102         Prob _stats;
00104         Prob _initial_stats;
00105 
00106     public:
00108 
00112         CondProbEstimation( size_t target_dimension, const Prob &pseudocounts );
00113 
00115 
00123         static ParameterEstimation* factory( const PropertySet &p );
00124 
00126         virtual ParameterEstimation* clone() const { return new CondProbEstimation( _target_dim, _initial_stats ); }
00127 
00129         virtual ~CondProbEstimation() {}
00130 
00132 
00135         virtual Prob estimate();
00136 
00138         virtual void addSufficientStatistics( const Prob &p );
00139 
00141         virtual size_t probSize() const { return _stats.size(); }
00142 };
00143 
00144 
00146 
00156 class SharedParameters {
00157     public:
00159         typedef size_t FactorIndex;
00161         typedef std::map<FactorIndex, std::vector<Var> > FactorOrientations;
00162 
00163     private:
00165         std::map<FactorIndex, VarSet> _varsets;
00167         std::map<FactorIndex, Permute> _perms;
00169         FactorOrientations _varorders;
00171         ParameterEstimation *_estimation;
00173         bool _ownEstimation;
00174 
00176 
00180         static Permute calculatePermutation( const std::vector<Var> &varOrder, VarSet &outVS );
00181 
00183         void setPermsAndVarSetsFromVarOrders();
00184 
00185     public:
00187 
00191         SharedParameters( const FactorOrientations &varorders, ParameterEstimation *estimation, bool ownPE=false );
00192 
00194 
00197         SharedParameters( std::istream &is, const FactorGraph &fg );
00198 
00200         SharedParameters( const SharedParameters &sp ) : _varsets(sp._varsets), _perms(sp._perms), _varorders(sp._varorders), _estimation(sp._estimation), _ownEstimation(sp._ownEstimation) {
00201             // If sp owns its _estimation object, we should clone it instead of copying the pointer
00202             if( _ownEstimation )
00203                 _estimation = _estimation->clone();
00204         }
00205 
00207         ~SharedParameters() {
00208             // If we own the _estimation object, we should delete it now
00209             if( _ownEstimation )
00210                 delete _estimation;
00211         }
00212 
00214 
00220         void collectSufficientStatistics( InfAlg &alg );
00221 
00223 
00228         void setParameters( FactorGraph &fg );
00229 };
00230 
00231 
00233 
00235 class MaximizationStep {
00236     private:
00238         std::vector<SharedParameters> _params;
00239 
00240     public:
00242         MaximizationStep() : _params() {}
00243 
00245         MaximizationStep( std::vector<SharedParameters> &maximizations ) : _params(maximizations) {}
00246 
00248 
00250         MaximizationStep( std::istream &is, const FactorGraph &fg_varlookup );
00251 
00253         void addExpectations( InfAlg &alg );
00254 
00256         void maximize( FactorGraph &fg );
00257 
00259 
00260 
00261         typedef std::vector<SharedParameters>::iterator iterator;
00263         typedef std::vector<SharedParameters>::const_iterator const_iterator;
00264 
00266         iterator begin() { return _params.begin(); }
00268         const_iterator begin() const { return _params.begin(); }
00270         iterator end() { return _params.end(); }
00272         const_iterator end() const { return _params.end(); }
00274 };
00275 
00276 
00278 
00295 class EMAlg {
00296     private:
00298         const Evidence &_evidence;
00299 
00301         InfAlg &_estep;
00302 
00304         std::vector<MaximizationStep> _msteps;
00305 
00307         size_t _iters;
00308 
00310         std::vector<Real> _lastLogZ;
00311 
00313         size_t _max_iters;
00314 
00316         Real _log_z_tol;
00317 
00318     public:
00320         static const std::string MAX_ITERS_KEY;
00322         static const size_t MAX_ITERS_DEFAULT;
00324         static const std::string LOG_Z_TOL_KEY;
00326         static const Real LOG_Z_TOL_DEFAULT;
00327 
00329 
00334         EMAlg( const Evidence &evidence, InfAlg &estep, std::vector<MaximizationStep> &msteps, const PropertySet &termconditions )
00335           : _evidence(evidence), _estep(estep), _msteps(msteps), _iters(0), _lastLogZ(), _max_iters(MAX_ITERS_DEFAULT), _log_z_tol(LOG_Z_TOL_DEFAULT)
00336         {
00337               setTermConditions( termconditions );
00338         }
00339 
00341 
00343         EMAlg( const Evidence &evidence, InfAlg &estep, std::istream &mstep_file );
00344 
00346 
00352         void setTermConditions( const PropertySet &p );
00353 
00355 
00361         bool hasSatisfiedTermConditions() const;
00362 
00364         Real logZ() const { return _lastLogZ.back(); }
00365 
00367         size_t Iterations() const { return _iters; }
00368 
00370         const InfAlg& eStep() const { return _estep; }
00371 
00373 
00375         Real iterate();
00376 
00378         Real iterate( MaximizationStep &mstep );
00379 
00381         void run();
00382 
00384 
00385 
00386         typedef std::vector<MaximizationStep>::iterator s_iterator;
00388         typedef std::vector<MaximizationStep>::const_iterator const_s_iterator;
00389 
00391         s_iterator s_begin() { return _msteps.begin(); }
00393         const_s_iterator s_begin() const { return _msteps.begin(); }
00395         s_iterator s_end() { return _msteps.end(); }
00397         const_s_iterator s_end() const { return _msteps.end(); }
00399 };
00400 
00401 
00402 } // end of namespace dai
00403 
00404 
00410 #endif