libDAI
include/dai/index.h
Go to the documentation of this file.
00001 /*  This file is part of libDAI - http://www.libdai.org/
00002  *
00003  *  Copyright (c) 2006-2011, The libDAI authors. All rights reserved.
00004  *
00005  *  Use of this source code is governed by a BSD-style license that can be found in the LICENSE file.
00006  */
00007 
00008 
00011 
00012 
00013 #ifndef __defined_libdai_index_h
00014 #define __defined_libdai_index_h
00015 
00016 
00017 #include <vector>
00018 #include <algorithm>
00019 #include <map>
00020 #include <dai/varset.h>
00021 
00022 
00023 namespace dai {
00024 
00025 
00027 
00048 class IndexFor {
00049     private:
00051         long                _index;
00052 
00054         std::vector<long>   _sum;
00055 
00057         std::vector<size_t> _state;
00058 
00060         std::vector<size_t> _ranges;
00061 
00062     public:
00064         IndexFor() : _index(-1) {}
00065 
00067         IndexFor( const VarSet& indexVars, const VarSet& forVars ) : _state( forVars.size(), 0 ) {
00068             long sum = 1;
00069 
00070             _ranges.reserve( forVars.size() );
00071             _sum.reserve( forVars.size() );
00072 
00073             VarSet::const_iterator j = forVars.begin();
00074             for( VarSet::const_iterator i = indexVars.begin(); i != indexVars.end(); ++i ) {
00075                 for( ; j != forVars.end() && *j <= *i; ++j ) {
00076                     _ranges.push_back( j->states() );
00077                     _sum.push_back( (*i == *j) ? sum : 0 );
00078                 }
00079                 sum *= i->states();
00080             }
00081             for( ; j != forVars.end(); ++j ) {
00082                 _ranges.push_back( j->states() );
00083                 _sum.push_back( 0 );
00084             }
00085             _index = 0;
00086         }
00087 
00089         IndexFor& reset() {
00090             fill( _state.begin(), _state.end(), 0 );
00091             _index = 0;
00092             return( *this );
00093         }
00094 
00096         operator size_t() const {
00097             DAI_ASSERT( valid() );
00098             return( _index );
00099         }
00100 
00102         IndexFor& operator++ () {
00103             if( _index >= 0 ) {
00104                 size_t i = 0;
00105 
00106                 while( i < _state.size() ) {
00107                     _index += _sum[i];
00108                     if( ++_state[i] < _ranges[i] )
00109                         break;
00110                     _index -= _sum[i] * _ranges[i];
00111                     _state[i] = 0;
00112                     i++;
00113                 }
00114 
00115                 if( i == _state.size() )
00116                     _index = -1;
00117             }
00118             return( *this );
00119         }
00120 
00122         void operator++( int ) {
00123             operator++();
00124         }
00125 
00127         bool valid() const {
00128             return( _index >= 0 );
00129         }
00130 };
00131 
00132 
00134 
00137 class Permute {
00138     private:
00140         std::vector<size_t>  _ranges;
00142         std::vector<size_t>  _sigma;
00143 
00144     public:
00146         Permute() : _ranges(), _sigma() {}
00147 
00149         Permute( const std::vector<size_t> &rs, const std::vector<size_t> &sigma ) : _ranges(rs), _sigma(sigma) {
00150             DAI_ASSERT( _ranges.size() == _sigma.size() );
00151         }
00152 
00154 
00158         Permute( const std::vector<Var> &vars, bool reverse=false ) : _ranges(), _sigma() {
00159             size_t N = vars.size();
00160 
00161             // construct ranges
00162             _ranges.reserve( N );
00163             for( size_t i = 0; i < N; ++i )
00164                 if( reverse )
00165                     _ranges.push_back( vars[N - 1 - i].states() );
00166                 else
00167                     _ranges.push_back( vars[i].states() );
00168 
00169             // construct VarSet out of vars
00170             VarSet vs( vars.begin(), vars.end(), N );
00171             DAI_ASSERT( vs.size() == N );
00172             
00173             // construct sigma
00174             _sigma.reserve( N );
00175             for( VarSet::const_iterator vs_i = vs.begin(); vs_i != vs.end(); ++vs_i ) {
00176                 size_t ind = find( vars.begin(), vars.end(), *vs_i ) - vars.begin();
00177                 if( reverse )
00178                     _sigma.push_back( N - 1 - ind );
00179                 else
00180                     _sigma.push_back( ind );
00181             }
00182         }
00183 
00185 
00188         size_t convertLinearIndex( size_t li ) const {
00189             size_t N = _ranges.size();
00190 
00191             // calculate vector index corresponding to linear index
00192             std::vector<size_t> vi;
00193             vi.reserve( N );
00194             size_t prod = 1;
00195             for( size_t k = 0; k < N; k++ ) {
00196                 vi.push_back( li % _ranges[k] );
00197                 li /= _ranges[k];
00198                 prod *= _ranges[k];
00199             }
00200 
00201             // convert permuted vector index to corresponding linear index
00202             prod = 1;
00203             size_t sigma_li = 0;
00204             for( size_t k = 0; k < N; k++ ) {
00205                 sigma_li += vi[_sigma[k]] * prod;
00206                 prod *= _ranges[_sigma[k]];
00207             }
00208 
00209             return sigma_li;
00210         }
00211 
00213         const std::vector<size_t>& sigma() const { return _sigma; }
00214 
00216         std::vector<size_t>& sigma() { return _sigma; }
00217 
00219         const std::vector<size_t>& ranges() { return _ranges; }
00220 
00222         size_t operator[]( size_t i ) const {
00223 #ifdef DAI_DEBUG
00224             return _sigma.at(i);
00225 #else
00226             return _sigma[i];
00227 #endif
00228         }
00229 
00231         Permute inverse() const {
00232             size_t N = _ranges.size();
00233             std::vector<size_t> invRanges( N, 0 );
00234             std::vector<size_t> invSigma( N, 0 );
00235             for( size_t i = 0; i < N; i++ ) {
00236                 invSigma[_sigma[i]] = i;
00237                 invRanges[i] = _ranges[_sigma[i]];
00238             }
00239             return Permute( invRanges, invSigma );
00240         }
00241 };
00242 
00243 
00245 
00263 class multifor {
00264     private:
00266         std::vector<size_t>  _ranges;
00268         std::vector<size_t>  _indices;
00270         long                 _linear_index;
00271 
00272     public:
00274         multifor() : _ranges(), _indices(), _linear_index(0) {}
00275 
00277         multifor( const std::vector<size_t> &d ) : _ranges(d), _indices(d.size(),0), _linear_index(0) {}
00278 
00280         operator size_t() const {
00281             DAI_DEBASSERT( valid() );
00282             return( _linear_index );
00283         }
00284 
00286         size_t operator[]( size_t k ) const {
00287             DAI_DEBASSERT( valid() );
00288             DAI_DEBASSERT( k < _indices.size() );
00289             return _indices[k];
00290         }
00291 
00293         multifor & operator++() {
00294             if( valid() ) {
00295                 _linear_index++;
00296                 size_t i;
00297                 for( i = 0; i != _indices.size(); i++ ) {
00298                     if( ++(_indices[i]) < _ranges[i] )
00299                         break;
00300                     _indices[i] = 0;
00301                 }
00302                 if( i == _indices.size() )
00303                     _linear_index = -1;
00304             }
00305             return *this;
00306         }
00307 
00309         void operator++( int ) {
00310             operator++();
00311         }
00312 
00314         multifor& reset() {
00315             fill( _indices.begin(), _indices.end(), 0 );
00316             _linear_index = 0;
00317             return( *this );
00318         }
00319 
00321         bool valid() const {
00322             return( _linear_index >= 0 );
00323         }
00324 };
00325 
00326 
00328 
00352 class State {
00353     private:
00355         typedef std::map<Var, size_t> states_type;
00356 
00358         BigInt                        state;
00359 
00361         states_type                   states;
00362 
00363     public:
00365         State() : state(0), states() {}
00366 
00368         State( const VarSet &vs, BigInt linearState=0 ) : state(linearState), states() {
00369             if( linearState == 0 )
00370                 for( VarSet::const_iterator v = vs.begin(); v != vs.end(); v++ )
00371                     states[*v] = 0;
00372             else {
00373                 for( VarSet::const_iterator v = vs.begin(); v != vs.end(); v++ ) {
00374                     states[*v] = BigInt_size_t( linearState % v->states() );
00375                     linearState /= v->states();
00376                 }
00377                 DAI_ASSERT( linearState == 0 );
00378             }
00379         }
00380 
00382         State( const std::map<Var, size_t> &s ) : state(0), states() {
00383             insert( s.begin(), s.end() );
00384         }
00385 
00387         typedef states_type::const_iterator const_iterator;
00388 
00390         const_iterator begin() const { return states.begin(); }
00391 
00393         const_iterator end() const { return states.end(); }
00394 
00396         operator size_t() const {
00397             DAI_ASSERT( valid() );
00398             return( BigInt_size_t( state ) );
00399         }
00400 
00402         template<typename InputIterator>
00403         void insert( InputIterator b, InputIterator e ) {
00404             states.insert( b, e );
00405             VarSet vars;
00406             for( const_iterator it = begin(); it != end(); it++ )
00407                 vars |= it->first;
00408             state = 0;
00409             state = this->operator()( vars );
00410         }
00411 
00413         const std::map<Var,size_t>& get() const { return states; }
00414 
00416         operator const std::map<Var,size_t>& () const { return states; }
00417 
00419         size_t operator() ( const Var &v ) const {
00420             states_type::const_iterator entry = states.find( v );
00421             if( entry == states.end() )
00422                 return 0;
00423             else
00424                 return entry->second;
00425         }
00426 
00428         BigInt operator() ( const VarSet &vs ) const {
00429             BigInt vs_state = 0;
00430             BigInt prod = 1;
00431             for( VarSet::const_iterator v = vs.begin(); v != vs.end(); v++ ) {
00432                 states_type::const_iterator entry = states.find( *v );
00433                 if( entry != states.end() )
00434                     vs_state += entry->second * prod;
00435                 prod *= v->states();
00436             }
00437             return vs_state;
00438         }
00439 
00441         void operator++( ) {
00442             if( valid() ) {
00443                 state++;
00444                 states_type::iterator entry = states.begin();
00445                 while( entry != states.end() ) {
00446                     if( ++(entry->second) < entry->first.states() )
00447                         break;
00448                     entry->second = 0;
00449                     entry++;
00450                 }
00451                 if( entry == states.end() )
00452                     state = -1;
00453             }
00454         }
00455 
00457         void operator++( int ) {
00458             operator++();
00459         }
00460 
00462         bool valid() const {
00463             return( state >= 0 );
00464         }
00465 
00467         void reset() {
00468             state = 0;
00469             for( states_type::iterator s = states.begin(); s != states.end(); s++ )
00470                 s->second = 0;
00471         }
00472 };
00473 
00474 
00475 } // end of namespace dai
00476 
00477 
00488 #endif