
[hts-users:01795] Fast & memory efficient HHEd clustering patch


Hi all,

# I'm no longer a maintainer of HTS,
# but still doing some work to improve it :-)

As you know, decision tree-based clustering is one of the most
computationally and memory-intensive parts of HTS training.  Especially
with a large database, the default HHEd implementation of decision
tree-based clustering consumes a huge amount of memory.

To avoid this problem, a reduced-memory version of decision tree-based
clustering has been included since HTS-2.0.  Unfortunately, compared
with the default HHEd implementation, the reduced-memory one is *slow*,
because it performs pattern matching between each model name and the
question every time a question is evaluated.  We need a new solution
that keeps both memory and computational cost low.

To achieve this, I added a new reduced-memory implementation to HHEd.
Please check the attached patch.  Once you apply it, you can choose the
memory reduction strategy with the -r option as follows:

 -r n    reduce memory usage on clustering        default: n=0
          0: no memory reduction
          1: mid reduction but fast
          2: large reduction but slow

"-r 0" corresponds to HHEd default implementation.  In this case, each
question holds a list of the models matching it.  These lists consume
huge memory when the database is large.  "-r 2" corresponds to the
previous reduced-memory implementation.  It stores no matching
information, so it has to repeat the pattern matching every time.
"-r 1" corresponds to the new reduced-memory implementation.  It builds
a 0/1 bit table (number of questions times number of models, in bits)
to store the question/model matching results.  Because it performs
pattern matching just once to create this table, it's fast.
Furthermore, it stores the matching results as a bit table rather than
as model lists, so it consumes less memory (this is not guaranteed, but
in many cases the new one uses much less memory than the default one).

Another advantage of this implementation is that we can easily predict
how much additional memory is required to store this table.  For
example, with 2 million context-dependent models in a model set and
4,000 questions (quite large model and question sets), this table
consumes 2,000,000 x 4,000 bits = 1 GByte.

When I clustered state duration models using
HTS-demo_CMU-ARCTIC-SLT+STRAIGHT on a Xeon 2.6GHz machine (single
thread), the memory usage and run time were as follows:

 -r 0 -> 245 MBytes  7min+45sec (HHEd default)
 -r 1 ->  57 MBytes  5min+10sec (new one)
 -r 2 ->  50 MBytes 15min+29sec (previous reduced memory implementation)

As you can see, the new one consumed slightly more memory than the
previous one (and much less than the default one) but was much faster.
In this case the number of models was around 40,000 and the number of
questions was 1,500 (so the predicted additional memory consumption is
7.5 MBytes).  With many more models and questions, the difference in
memory usage between "-r 1" and "-r 2" becomes larger.  Interestingly,
in this case the new one was even faster than the default HHEd
implementation!

This patch may have bugs.  Please test it and check whether it works.
If you find any problems, please report them to the hts-users mailing
list.  I hope this modification improves the scalability of HTS, and I
would be pleased if it contributes to the further progress of both HTS
itself and the HTS community.

Best regards,

Heiga ZEN (Byung Ha CHUN)

--
Heiga ZEN (Byung Ha CHUN)
Speech Technology Group
Cambridge Research Lab
Toshiba Research Europe
phone: +44 1223 436975




diff -bBEwur HTKLib/HModel.c HTKLib/HModel.c
--- HTKLib/HModel.c	2008-06-24 03:19:08.000000000 +0000
+++ HTKLib/HModel.c	2008-11-18 16:43:56.990558000 +0000
@@ -2296,7 +2296,7 @@
       HMError(src,"NumStates not set");
       return(FAIL);
    }
-   hmm->numStates = N = nState;
+   hmm->numStates = N = nState;  hmm->hIdx=0;
    se = (StateElem *)New(hset->hmem,(N-2)*sizeof(StateElem));
    hmm->svec = se - 2;
    while (tok->sym == STATE) {
@@ -4097,7 +4097,7 @@
                dset=*hset;
                dset.hmem=&gstack;
                dhmm = (HLink) New(&gstack,sizeof(HMMDef));
-               dhmm->owner=NULL; dhmm->numStates=0; dhmm->nUse=0; dhmm->hook=NULL;
+               dhmm->owner=NULL; dhmm->numStates=0; dhmm->nUse=0; dhmm->hook=NULL; dhmm->hIdx=0;
                if (trace&T_MAC)
                   printf("HModel: skipping HMM Def from macro %s\n",id->name);
                if (GetToken(&src,&tok)<SUCCESS) {
@@ -4249,11 +4249,11 @@
    StreamInfo *sti;
    MixPDF *mp;
    MLink m;
-   int h,nm,nsm,ns,nss,nsp,np,nt;
+   int h,nm,nsm,ns,nss,nsp,np,nh,nt;
    
    /* Reset indexes */
    hset->indexSet = TRUE;
-   nt=0;
+   nt=nh=0;
    NewHMMScan(hset,&hss);
    while (GoNextState(&hss,FALSE))
       hss.si->sIdx=-1;
@@ -4280,6 +4280,7 @@
          TouchV(hss.hmm->transP);
       }
       hss.hmm->tIdx=(int)((long)GetHook(hss.hmm->transP));
+      hss.hmm->hIdx=nh++;
    } while (GoNextHMM(&hss));
    EndHMMScan(&hss);
    NewHMMScan(hset,&hss);
diff -bBEwur HTKLib/HModel.h HTKLib/HModel.h
--- HTKLib/HModel.h	2008-06-19 04:31:41.000000000 +0000
+++ HTKLib/HModel.h	2008-11-18 13:34:23.412879000 +0000
@@ -201,6 +201,7 @@
    SVector dur;            /* vector of model duration params, if any */   
    SMatrix transP;         /* transition matrix (logs) */
    int tIdx;               /* Transition matrix index */
+   int hIdx;               /* hmm index */
    int nUse;               /* num logical hmm's sharing this def */
    Ptr hook;               /* general hook */
 } HMMDef;
--- HTKTools/HHEd.c	2008-08-07 13:36:19.341080000 +0000
+++ HTKTools/HHEd.c	2008-11-19 10:47:16.229385000 +0000
@@ -191,7 +191,7 @@
 static float MDLfactor = 1.0;                 /* MDL control factor */
 static Boolean ignoreStrW = FALSE;            /* ignore stream weight */
 static float minVar = 1.0E-6;                 /* minimum variance for clusterd item */
-static Boolean reduceMem = FALSE;             /* reduce memory requirement for decision-tree clustering */
+static int reduceMem = 0;                     /* reduce memory requirement for decision-tree clustering (0: no reduction, 1: mid reduction but fast, 2: large reduction but slow) */
 static float minLeafOcc = 0.0;                /* minimum occ for each leaf node */
 static float minMixOcc = 0.0;                 /* minimum occ for each mix */
 static Vector shrinkOccThresh=NULL;           /* occupancy threshold for shrinking decision trees */
@@ -216,7 +216,7 @@
       if (GetConfBool(cParm,nParm,"SINGLETREE",&b)) singleTree = b;
       if (GetConfBool(cParm,nParm,"APPLYMDL",&b)) applyMDL = b;
       if (GetConfBool(cParm,nParm,"IGNORESTRW",&b)) ignoreStrW = b;
-      if (GetConfBool(cParm,nParm,"REDUCEMEM",&b)) reduceMem = b;
+      if (GetConfInt(cParm,nParm,"REDUCEMEM",&i)) reduceMem = i;
       if (GetConfFlt(cParm,nParm,"MINVAR",&f)) minVar = f;
       if (GetConfFlt(cParm,nParm,"MDLFACTOR",&f)) MDLfactor = f;
       if (GetConfFlt(cParm,nParm,"MINLEAFOCC",&f)) minLeafOcc = f;
@@ -306,7 +306,10 @@
    printf(" -m      apply MDL principle for clustering                off\n");
    printf(" -o s    extension for new hmm files          as source\n");
    printf(" -p      use pattern instead of base phone                 off\n");
-   printf(" -r      reduce memory usage on clustering                 off\n");
+   printf(" -r n    reduce memory usage on clustering                 0\n");
+   printf("          0: no memory reduction                           \n");
+   printf("          1: mid reduction but fast                        \n");
+   printf("          2: large reduction but slow                      \n");
    printf(" -s      construct single tree                             off\n");
    printf(" -v f    Set minimum variance to f                         1.0E-6\n");
    printf(" -w mmf  Save all HMMs to macro file mmf s    as source\n");
@@ -363,7 +366,7 @@
       case 'p':
          usePattern = TRUE; break;
       case 'r':
-         reduceMem = TRUE; break;
+         reduceMem = GetChkedInt(0,2,s); break;
       case 's':
          singleTree = TRUE; break;
       case 'v':
@@ -475,6 +478,10 @@
 
 /* ------------- Question Handling for Tree Building ----------- */
 
+static int nQuestions=0;
+static char **QMTable=NULL;
+static Boolean setQMTable=FALSE;
+
 typedef struct _IPat{
    char *pat;
    struct _IPat *next;
@@ -485,6 +492,7 @@
    IPat *patList;               
    ILink ilist;
    char pattern[PAT_LEN];
+   int index;
    Boolean used;
 } Question;
 
@@ -599,7 +607,7 @@
    }
    
    q=(Question *) New(&questHeap,sizeof(Question));
-   q->used=FALSE; q->qName=labid; q->patList = NULL;
+   q->used=FALSE; q->qName=labid; q->patList=NULL; q->index=nQuestions++;
    q->ilist = ilist;
    labid->aux=q; 
    strcpy(q->pattern, pattern);
@@ -3198,25 +3204,77 @@
    return(prob);
 }
 
+/* SetQMTable: Set Question x Model bit table */
+void SetQMTable (void) 
+{
+   char *name;
+   int i,j;
+   ILink p;
+   Boolean answer;
+   Question *q;
+   HMMScanState hss;
+
+   const int size=(hset->numPhyHMM>>3)+1;
+
+   QMTable = (char **) New(&questHeap, nQuestions*sizeof(char *));
+   
+   for (i=0; i<nQuestions; i++) {
+      QMTable[i] = (char *) New(&questHeap, size*sizeof(char));
+      memset(QMTable[i], 0, size);
+   }
+
+   NewHMMScan(hset,&hss);
+   do {
+      /* index and name of current hmm */
+      i = hss.hmm->hIdx >> 3;  /* division by 8 */
+      j = hss.hmm->hIdx &  7;  /* residue by 8 */
+      name = hss.mac->id->name;
+
+      /* check answers */
+      for (p=qList; p!=NULL; p=p->next) {
+         q = (Question *)p->item;
+         answer = QMatch(name, q);
+         if (answer)
+            QMTable[q->index][i] |= (1<<j);  /* set j-th bit */
+      }
+   } while(GoNextHMM(&hss));
+   EndHMMScan(&hss);
+
+   setQMTable = TRUE;
+
+   return;
+}
+
 /* AnswerQuestion: set ans field in each cluster item in preparation for
    a possible split */
 Boolean AnswerQuestion(Node *node, Question *q)
 {
    CLink p;
    ILink i;
-   int yes=0, no=0;
-  
-   if (reduceMem) {
       MLink m;
+   int yes=0, no=0, index, j;
       
+   switch(reduceMem) {
+   case 2:
       for (p=node->clist;p!=NULL;p=p->next) {
          m = (MLink)p->item->owner->hook;
          p->ans = QMatch(m->id->name, q);
          if (p->ans) yes++;
          else        no++;
       }
+      break;
+   case 1:
+      index = q->index;
+      if (!setQMTable) SetQMTable();
+      for (p=node->clist; p!=NULL; p=p->next) {
+         j = p->item->owner->hIdx & 7;  /* residue by 8 */
+         p->ans = (QMTable[index][p->item->owner->hIdx>>3] & (1<<j)) ? TRUE : FALSE;
+         if (p->ans) yes++;
+         else        no++;
    }
-   else {
+      break;
+   case 0:
+   default: 
       /* set ans=FALSE for items in this cluster */
       for (p=node->clist;p!=NULL;p=p->next) {
       p->ans = FALSE;
@@ -6244,7 +6301,7 @@
    ChkedAlpha("QS question name",qName);
    
    /* get copy of original item list whilst parsing it */
-   if (reduceMem) {
+   if (reduceMem!=0) {
       char pattern[PAT_LEN]; 
       ReadLine(&source, pattern);
       if (trace & T_QST) {
@@ -6252,6 +6309,7 @@
          fflush(stdout);
       }
       LoadQuestion(qName,NULL,pattern);
+      setQMTable=FALSE;
    }
    else {
       ILink ilist=NULL;