trie.h

Go to the documentation of this file.
00001 
00023 #ifndef _TRIE_H
00024 #define _TRIE_H
00025 
00026 #include <assert.h>
00027 #include "defines.h"
00028 
00029 using namespace std;
00030 
/**
 * A trie over sequences of symbols of type _type. Every node stores a symbol,
 * a frequency counter and its children; nodes also keep a back-pointer to
 * their parent so calculateProbability can walk from a context back to the root.
 *
 * NOTE(review): the implicitly generated copy constructor copies the node
 * lists, but the copied nodes' parent pointers still point into the source
 * trie, so copy-constructing a Trie is not safe. Copy assignment is ill-formed
 * anyway because Node has const members.
 */
template<typename _type>
class Trie {
	private:
		/// One trie node: a symbol, its parent, a frequency counter and its children.
		class Node {
			public:
				/// the symbol stored in this node (0 for the root)
				const _type value;
				/// parent node; NULL for the root
				const Node* parent;
				/// starts at 1; incremented by insertSequence when updateFrequency is set
				unsigned long frequency;
				/// child nodes; insertSequence adds at most one child per symbol value
				std::list<Node> childs;

				// this is only used for initialization in the Trie constructor
				Node() : value(0), parent(NULL), frequency(1) {}

				/// creates a childless node holding val below the node par
				Node(const _type val, const Node* par) : value(val), parent(par), frequency(1) {}

				/// Renders this subtree, one line per node, indented two spaces per level.
				std::string toString(unsigned int level) const {
					std::string ret;
					for (unsigned int i = 0; i < level; i++)
						ret += "  ";
					// Control characters (notably the root's value 0) are shown as
					// '.': a raw '\0' would end buf early and drop the rest of the
					// line, and characters like '\n' would break the layout.
					char shown = (char) value;
					if ((unsigned char) shown < 0x20 || shown == 0x7f)
						shown = '.';
					// The indentation is appended above rather than formatted into
					// buf, so the formatted part is bounded (two longs plus a few
					// characters, well under 64 bytes) regardless of trie depth.
					char buf[64];
					sprintf(buf, "%ld (%c): %lu\n", (long) value, shown, frequency);
					ret += buf;
					for (typename std::list<Node>::const_iterator child = childs.begin(); child != childs.end(); ++child)
						ret += child->toString(level + 1);
					return ret;
				}
		};

		/// the root node; it holds no symbol of its own
		Node root;

	public:
		/// creates a trie that consists of the root node only
		Trie();
		~Trie();

		/// returns true if the sequence is a path starting at the root (an empty sequence always is)
		bool findSequence(const std::list<_type>* sequence) const;
		/// inserts the sequence as a path below the root, creating missing nodes;
		/// this also increments the frequency of the final node if updateFrequency is true
		void insertSequence(const std::list<_type>* sequence, bool updateFrequency=false);
		/// returns a probability estimate for nextValue following context, blending the
		/// counts of the context node and its ancestors; context must be in the trie
		double calculateProbability(const std::list<_type>* context, const _type nextValue) const;

		/// a human-readable output
		std::string toString() const;
		/// an output in text format which can be read by unserialize (currently throws)
		std::string serialize() const;
		/// restores the trie from the output of serialize (currently throws)
		void unserialize(std::string data);

	private:
		// a helper function used by both findSequence and calculateProbability;
		// returns the node of the last symbol, &root for an empty sequence, or
		// NULL if the sequence is not in the trie
		const Node* navigateTree(const std::list<_type>* sequence) const
		{
			const Node* cur = &root;
			for (typename std::list<_type>::const_iterator symbol = sequence->begin(); symbol != sequence->end(); ++symbol) {
				const Node* next = NULL;
				for (typename std::list<Node>::const_iterator child = cur->childs.begin(); child != cur->childs.end(); ++child) {
					if (child->value == *symbol) {
						// found the symbol at the correct trie position, go one level deeper
						next = &(*child);
						break;
					}
				}
				if (next == NULL)
					return NULL;	// not found --> the sequence is not in the trie
				cur = next;
			}
			// all symbols in the sequence found, return the last node
			return cur;
		}
};
00127 
00128 template<typename _type>
00129 Trie<_type>::Trie() {}
00130 
00131 template<typename _type>
00132 Trie<_type>::~Trie() {}
00133 
00134 template<typename _type>
00135 bool Trie<_type>::findSequence(const list<_type>* sequence) const
00136 {
00137         return (navigateTree(sequence) != NULL);
00138 }
00139 
00140 template<typename _type>
00141 void Trie<_type>::insertSequence(const list<_type>* sequence, bool updateFrequency)
00142 {
00143         Node* cur = &root;
00144         for (typename list<_type>::const_iterator symbol = sequence->begin(); symbol != sequence->end(); symbol++) {
00145                 bool foundsymb = false;
00146                 for (typename list<Node>::iterator node = cur->childs.begin(); !foundsymb && node != cur->childs.end(); node++) {
00147                         if ((*node).value == *symbol) {
00148                                 foundsymb = true;
00149                                 cur = &(*node);
00150                                 // and update the frequency of that symbol in that level - this is the old behavior for Lempel Ziv
00151                                 //(*node).frequency++;
00152                         }
00153                 }
00154                 if (!foundsymb) {
00155                         // insert a new one at that level
00156                         Node newNode(*symbol, cur);
00157                         cur->childs.push_back(newNode);
00158                         cur = &newNode;
00159                 }
00160         }
00161         // increase the frequency of the last symbol seen in here
00162         if (updateFrequency)
00163                 (*cur).frequency++;
00164 }
00165 
00166 template<typename _type>
00167 double Trie<_type>::calculateProbability(const list<_type>* context, const _type nextValue) const
00168 {
00169 //      printf("++++\n");
00170         double prob = 0.0, escapeProb = 1.0;
00171         unsigned long sumfreq, foundfreq;
00172         // to calculate the probability of some next symbol, we first need to navigate to the given context
00173         const Node* cur = navigateTree(context);
00174         assert(cur != NULL);
00175         /* but it is not allowed to have no childs at the node we are starting
00176            from (because then the inner loop would not run, sumfreq and
00177            foundfreq would be 0 and we would finally exit with probability 0) */
00178         if (cur->childs.size() == 0)
00179                 cur = cur->parent;
00180 
00181         while (cur != &root && cur != NULL) {
00182 //              printf("++++ at %d (%c): ", cur->value, cur->value);
00183                 sumfreq = foundfreq = 0;
00184                 // also count the sum of frequencies
00185                 for (typename list<Node>::const_iterator node = cur->childs.begin(); node != cur->childs.end(); node++) {
00186 //                      printf("%c ", (*node).value);
00187                         sumfreq += (*node).frequency;
00188                         // this can only happen once (because of the insertion procedure)
00189                         if ((*node).value == nextValue) {
00190                                 assert(foundfreq == 0);
00191                                 foundfreq = (*node).frequency;
00192                         }
00193                 }
00194                 prob += escapeProb * ((double) foundfreq / cur->frequency);
00195                 escapeProb *= ((double) sumfreq / cur->frequency);
00196 //              printf("curProb=%f, curEscapeProb=%f, prob=%f, escapeProb=%f\n", escapeProb * ((double) foundfreq / cur->frequency), ((double) sumfreq / cur->frequency), prob, escapeProb);
00197                 cur = cur->parent;
00198         }
00199         sumfreq = foundfreq = 0;
00200         for (typename list<Node>::const_iterator node = cur->childs.begin(); node != cur->childs.end(); node++) {
00201                 sumfreq += (*node).frequency;
00202                 // this can only happen once (because of the insertion procedure)
00203                 if ((*node).value == nextValue) {
00204                         assert(foundfreq == 0);
00205                         foundfreq = (*node).frequency;
00206                 }
00207         }
00208         prob += escapeProb *((double) foundfreq / sumfreq);
00209         return prob;
00210 }
00211 
00212 template<typename _type>
00213 string Trie<_type>::toString() const
00214 {
00215         return root.toString(0);
00216 }
00217 
00218 template<typename _type>
00219 string Trie<_type>::serialize() const
00220 {
00221         throw 1;
00222         return "";
00223 }
00224 
00225 template<typename _type>
00226 void Trie<_type>::unserialize(string data)
00227 {
00228         throw 1;
00229 }
00230 
00231 #endif

Generated on Mon Jun 5 10:20:48 2006 for Intelligence.kdevelop by doxygen 1.4.6