// libDAI
00001 /* This file is part of libDAI - http://www.libdai.org/ 00002 * 00003 * Copyright (c) 2006-2011, The libDAI authors. All rights reserved. 00004 * 00005 * Use of this source code is governed by a BSD-style license that can be found in the LICENSE file. 00006 */ 00007 00008 00011 00012 00013 #ifndef __defined_libdai_index_h 00014 #define __defined_libdai_index_h 00015 00016 00017 #include <vector> 00018 #include <algorithm> 00019 #include <map> 00020 #include <dai/varset.h> 00021 00022 00023 namespace dai { 00024 00025 00027 00048 class IndexFor { 00049 private: 00051 long _index; 00052 00054 std::vector<long> _sum; 00055 00057 std::vector<size_t> _state; 00058 00060 std::vector<size_t> _ranges; 00061 00062 public: 00064 IndexFor() : _index(-1) {} 00065 00067 IndexFor( const VarSet& indexVars, const VarSet& forVars ) : _state( forVars.size(), 0 ) { 00068 long sum = 1; 00069 00070 _ranges.reserve( forVars.size() ); 00071 _sum.reserve( forVars.size() ); 00072 00073 VarSet::const_iterator j = forVars.begin(); 00074 for( VarSet::const_iterator i = indexVars.begin(); i != indexVars.end(); ++i ) { 00075 for( ; j != forVars.end() && *j <= *i; ++j ) { 00076 _ranges.push_back( j->states() ); 00077 _sum.push_back( (*i == *j) ? 
sum : 0 ); 00078 } 00079 sum *= i->states(); 00080 } 00081 for( ; j != forVars.end(); ++j ) { 00082 _ranges.push_back( j->states() ); 00083 _sum.push_back( 0 ); 00084 } 00085 _index = 0; 00086 } 00087 00089 IndexFor& reset() { 00090 fill( _state.begin(), _state.end(), 0 ); 00091 _index = 0; 00092 return( *this ); 00093 } 00094 00096 operator size_t() const { 00097 DAI_ASSERT( valid() ); 00098 return( _index ); 00099 } 00100 00102 IndexFor& operator++ () { 00103 if( _index >= 0 ) { 00104 size_t i = 0; 00105 00106 while( i < _state.size() ) { 00107 _index += _sum[i]; 00108 if( ++_state[i] < _ranges[i] ) 00109 break; 00110 _index -= _sum[i] * _ranges[i]; 00111 _state[i] = 0; 00112 i++; 00113 } 00114 00115 if( i == _state.size() ) 00116 _index = -1; 00117 } 00118 return( *this ); 00119 } 00120 00122 void operator++( int ) { 00123 operator++(); 00124 } 00125 00127 bool valid() const { 00128 return( _index >= 0 ); 00129 } 00130 }; 00131 00132 00134 00137 class Permute { 00138 private: 00140 std::vector<size_t> _ranges; 00142 std::vector<size_t> _sigma; 00143 00144 public: 00146 Permute() : _ranges(), _sigma() {} 00147 00149 Permute( const std::vector<size_t> &rs, const std::vector<size_t> &sigma ) : _ranges(rs), _sigma(sigma) { 00150 DAI_ASSERT( _ranges.size() == _sigma.size() ); 00151 } 00152 00154 00158 Permute( const std::vector<Var> &vars, bool reverse=false ) : _ranges(), _sigma() { 00159 size_t N = vars.size(); 00160 00161 // construct ranges 00162 _ranges.reserve( N ); 00163 for( size_t i = 0; i < N; ++i ) 00164 if( reverse ) 00165 _ranges.push_back( vars[N - 1 - i].states() ); 00166 else 00167 _ranges.push_back( vars[i].states() ); 00168 00169 // construct VarSet out of vars 00170 VarSet vs( vars.begin(), vars.end(), N ); 00171 DAI_ASSERT( vs.size() == N ); 00172 00173 // construct sigma 00174 _sigma.reserve( N ); 00175 for( VarSet::const_iterator vs_i = vs.begin(); vs_i != vs.end(); ++vs_i ) { 00176 size_t ind = find( vars.begin(), vars.end(), *vs_i ) - 
vars.begin(); 00177 if( reverse ) 00178 _sigma.push_back( N - 1 - ind ); 00179 else 00180 _sigma.push_back( ind ); 00181 } 00182 } 00183 00185 00188 size_t convertLinearIndex( size_t li ) const { 00189 size_t N = _ranges.size(); 00190 00191 // calculate vector index corresponding to linear index 00192 std::vector<size_t> vi; 00193 vi.reserve( N ); 00194 size_t prod = 1; 00195 for( size_t k = 0; k < N; k++ ) { 00196 vi.push_back( li % _ranges[k] ); 00197 li /= _ranges[k]; 00198 prod *= _ranges[k]; 00199 } 00200 00201 // convert permuted vector index to corresponding linear index 00202 prod = 1; 00203 size_t sigma_li = 0; 00204 for( size_t k = 0; k < N; k++ ) { 00205 sigma_li += vi[_sigma[k]] * prod; 00206 prod *= _ranges[_sigma[k]]; 00207 } 00208 00209 return sigma_li; 00210 } 00211 00213 const std::vector<size_t>& sigma() const { return _sigma; } 00214 00216 std::vector<size_t>& sigma() { return _sigma; } 00217 00219 const std::vector<size_t>& ranges() { return _ranges; } 00220 00222 size_t operator[]( size_t i ) const { 00223 #ifdef DAI_DEBUG 00224 return _sigma.at(i); 00225 #else 00226 return _sigma[i]; 00227 #endif 00228 } 00229 00231 Permute inverse() const { 00232 size_t N = _ranges.size(); 00233 std::vector<size_t> invRanges( N, 0 ); 00234 std::vector<size_t> invSigma( N, 0 ); 00235 for( size_t i = 0; i < N; i++ ) { 00236 invSigma[_sigma[i]] = i; 00237 invRanges[i] = _ranges[_sigma[i]]; 00238 } 00239 return Permute( invRanges, invSigma ); 00240 } 00241 }; 00242 00243 00245 00263 class multifor { 00264 private: 00266 std::vector<size_t> _ranges; 00268 std::vector<size_t> _indices; 00270 long _linear_index; 00271 00272 public: 00274 multifor() : _ranges(), _indices(), _linear_index(0) {} 00275 00277 multifor( const std::vector<size_t> &d ) : _ranges(d), _indices(d.size(),0), _linear_index(0) {} 00278 00280 operator size_t() const { 00281 DAI_DEBASSERT( valid() ); 00282 return( _linear_index ); 00283 } 00284 00286 size_t operator[]( size_t k ) const { 00287 
DAI_DEBASSERT( valid() ); 00288 DAI_DEBASSERT( k < _indices.size() ); 00289 return _indices[k]; 00290 } 00291 00293 multifor & operator++() { 00294 if( valid() ) { 00295 _linear_index++; 00296 size_t i; 00297 for( i = 0; i != _indices.size(); i++ ) { 00298 if( ++(_indices[i]) < _ranges[i] ) 00299 break; 00300 _indices[i] = 0; 00301 } 00302 if( i == _indices.size() ) 00303 _linear_index = -1; 00304 } 00305 return *this; 00306 } 00307 00309 void operator++( int ) { 00310 operator++(); 00311 } 00312 00314 multifor& reset() { 00315 fill( _indices.begin(), _indices.end(), 0 ); 00316 _linear_index = 0; 00317 return( *this ); 00318 } 00319 00321 bool valid() const { 00322 return( _linear_index >= 0 ); 00323 } 00324 }; 00325 00326 00328 00352 class State { 00353 private: 00355 typedef std::map<Var, size_t> states_type; 00356 00358 BigInt state; 00359 00361 states_type states; 00362 00363 public: 00365 State() : state(0), states() {} 00366 00368 State( const VarSet &vs, BigInt linearState=0 ) : state(linearState), states() { 00369 if( linearState == 0 ) 00370 for( VarSet::const_iterator v = vs.begin(); v != vs.end(); v++ ) 00371 states[*v] = 0; 00372 else { 00373 for( VarSet::const_iterator v = vs.begin(); v != vs.end(); v++ ) { 00374 states[*v] = BigInt_size_t( linearState % v->states() ); 00375 linearState /= v->states(); 00376 } 00377 DAI_ASSERT( linearState == 0 ); 00378 } 00379 } 00380 00382 State( const std::map<Var, size_t> &s ) : state(0), states() { 00383 insert( s.begin(), s.end() ); 00384 } 00385 00387 typedef states_type::const_iterator const_iterator; 00388 00390 const_iterator begin() const { return states.begin(); } 00391 00393 const_iterator end() const { return states.end(); } 00394 00396 operator size_t() const { 00397 DAI_ASSERT( valid() ); 00398 return( BigInt_size_t( state ) ); 00399 } 00400 00402 template<typename InputIterator> 00403 void insert( InputIterator b, InputIterator e ) { 00404 states.insert( b, e ); 00405 VarSet vars; 00406 for( 
const_iterator it = begin(); it != end(); it++ ) 00407 vars |= it->first; 00408 state = 0; 00409 state = this->operator()( vars ); 00410 } 00411 00413 const std::map<Var,size_t>& get() const { return states; } 00414 00416 operator const std::map<Var,size_t>& () const { return states; } 00417 00419 size_t operator() ( const Var &v ) const { 00420 states_type::const_iterator entry = states.find( v ); 00421 if( entry == states.end() ) 00422 return 0; 00423 else 00424 return entry->second; 00425 } 00426 00428 BigInt operator() ( const VarSet &vs ) const { 00429 BigInt vs_state = 0; 00430 BigInt prod = 1; 00431 for( VarSet::const_iterator v = vs.begin(); v != vs.end(); v++ ) { 00432 states_type::const_iterator entry = states.find( *v ); 00433 if( entry != states.end() ) 00434 vs_state += entry->second * prod; 00435 prod *= v->states(); 00436 } 00437 return vs_state; 00438 } 00439 00441 void operator++( ) { 00442 if( valid() ) { 00443 state++; 00444 states_type::iterator entry = states.begin(); 00445 while( entry != states.end() ) { 00446 if( ++(entry->second) < entry->first.states() ) 00447 break; 00448 entry->second = 0; 00449 entry++; 00450 } 00451 if( entry == states.end() ) 00452 state = -1; 00453 } 00454 } 00455 00457 void operator++( int ) { 00458 operator++(); 00459 } 00460 00462 bool valid() const { 00463 return( state >= 0 ); 00464 } 00465 00467 void reset() { 00468 state = 0; 00469 for( states_type::iterator s = states.begin(); s != states.end(); s++ ) 00470 s->second = 0; 00471 } 00472 }; 00473 00474 00475 } // end of namespace dai 00476 00477 00488 #endif