/*
 YamCha -- Yet Another Multipurpose CHunk Annotator

 $Id: svm_model.cc,v 1.15 2001/06/19 09:36:09 taku-ku Exp $;

 Copyright (C) 2001  Taku Kudoh <taku-ku.aist-nara.ac.jp>
 All rights reserved.

 This library is free software; you can redistribute it and/or
 modify it under the terms of the GNU Library General Public
 License as published by the Free Software Foundation; either
 version 2 of the License, or (at your option) any later version.

 This library is distributed in the hope that it will be useful,
 but WITHOUT ANY WARRANTY; without even the implied warranty of
 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 Library General Public License for more details.

 You should have received a copy of the GNU Library General Public
 License along with this library; if not, write to the
 Free Software Foundation, Inc., 59 Temple Place - Suite 330,
 Boston, MA 02111-1307, USA.
*/
#include "common.h"
#include "svm_model.h"
#include "math.h"

// $Id: svm_model.cc,v 1.15 2001/06/19 09:36:09 taku-ku Exp $;

namespace YamCha {

// Construct an empty model.
// All owned arrays start as null pointers, and readSize, modelSize,
// classSize and doubleArraySize start at 0. The destructor can therefore
// run (freeing nothing) even if readModel() is never called.
// The entries are listed in the same order as the original initializer list.
svmModel::svmModel()
  : dotBuf(0), dotCache(0), _resultList(0), resultList(0),
    readSize(0),
    alphaList(0), fi(0), table(0),
    doubleArraySize(0), doubleArray(0),
    modelList(0), modelSize(0),
    classSize(0), classList(0)
{
}

// Release every array allocated by readModel().
// The constructor sets each pointer to 0, and delete [] on a null pointer
// is a no-op, so this is also safe for a model that was never loaded.
//
// NOTE(review): this class owns raw arrays. If svm_model.h does not
// disable copy construction/assignment, a copied svmModel would
// double-free here; declaring them private there would prevent it.
svmModel::~svmModel()
{
  // readModel() reads classSize from the file before it allocates
  // classList. An exception between those two points leaves classSize > 0
  // while classList is still 0, so guard the per-entry loop.
  if (classList) {
    for (unsigned int i = 0; i < classSize; i++) delete [] classList[i];
  }
  delete [] classList;
  delete [] alphaList;
  delete [] modelList;
  delete [] table;
  delete [] fi;
  delete [] dotBuf;
  delete [] dotCache;
  delete [] _resultList;
  delete [] resultList;
  delete [] doubleArray;
}

// Read exactly `size` bytes from fd into ptr, then add `size` to readSize.
// readModel() compares readSize against the file size.
//
// read() may legally return fewer bytes than requested, so keep reading
// until the request is satisfied. Reaching end of file (0) or an error (-1)
// first throws string("svmModel::read(): read failed").
void svmModel::read_buf(int fd, void *ptr, size_t size)
{
  char *dst = static_cast<char *>(ptr);
  size_t remaining = size;

  while (remaining > 0) {
    // long holds both POSIX ssize_t and the int returned by Windows _read()
    const long got = read(fd, dst, remaining);
    if (got <= 0) {
      throw string("svmModel::read(): read failed");
    }
    dst       += got;
    remaining -= static_cast<size_t>(got);
  }

  readSize += size;
}

int svmModel::readModel(const string &filename)
{
  int fd;
  unsigned int size;

  /* read text */
  if ((fd = open(filename.c_str(), O_RDONLY|O_BINARY)) < 0) {
    string tmp = ("svmModel::readModel(): " + filename + ": no such file or directory");
    throw tmp;
  }

  // get size of filename (not use 'stat', because of portability)
  lseek (fd,0,SEEK_END);
  size = lseek(fd,0,SEEK_CUR);
  lseek (fd,0,SEEK_SET);

  // kernel specfic param.
  read_buf (fd, version, sizeof (char) * sizeof (version));

  // check version
  if (atof(version) != MODEL_VERSION) {
    string tmp = "svmModel::readModel(): model version is different: ";
    throw tmp;
  }
   
  read_buf (fd, kernel_type,   sizeof (char) * sizeof (kernel_type) );
  read_buf (fd, &param_degree, sizeof (unsigned int) );
  read_buf (fd, &param_g,      sizeof (double) );
  read_buf (fd, &param_r,      sizeof (double) );
  read_buf (fd, &param_s,      sizeof (double) );

  // model specfic
  read_buf (fd, &modelSize,             sizeof (unsigned int) );
  read_buf (fd, &classSize,             sizeof (unsigned int) );
  read_buf (fd, &alphaSize,             sizeof (unsigned int) );
  read_buf (fd, &svSize,                sizeof (unsigned int) );
  read_buf (fd, &tableSize,             sizeof (unsigned int) );
  read_buf (fd, &dimensionSize,         sizeof (unsigned int) );
  read_buf (fd, &nonzeroDimensionSize,  sizeof (unsigned int) );

  // Double Array 
  read_buf (fd, &doubleArraySize, sizeof (unsigned int) );

  // read model prameters 
  int param_size;
  read_buf (fd, &param_size, sizeof (unsigned int) );
  char *param_str =  new char [param_size + 1];
  read_buf (fd, param_str, sizeof (char) * param_size);
  int pos = 0;
  while (pos < param_size) {
    char *key =  (param_str + pos);
    while (param_str[++pos] != '\0');
    pos++;
    char *value = param_str + pos;
    paramHash[string(key)] = string(value);
    while (param_str[++pos] != '\0');
    pos++;
  }
  delete [] param_str;

  // classList, list of fixied record (32)
  classList = new char * [classSize];
  for (unsigned int i = 0; i < classSize; i++) {
    classList[i] = new char [32];
    read_buf (fd, classList[i], sizeof (char) * 32);
    class2idHash[string(classList[i])] = i+1;
  }

  // model
  modelList = new _Model [modelSize];
  for (unsigned int i = 0; i < modelSize; i++) {
     int pos,neg;
     double b;
     read_buf(fd, &pos, sizeof (unsigned int) );
     read_buf(fd, &neg, sizeof (unsigned int) );
     read_buf(fd, &b,   sizeof (double) );
     modelList[i].pos = pos;
     modelList[i].neg = neg;
     modelList[i].b = b;
  }

  // alpha, tricky, including dummy filelds
  alphaList = new _Alpha[alphaSize + modelSize];
  for (unsigned int i = 0; i < alphaSize + modelSize; i++) {
     int id;
     double alpha;
     read_buf(fd, &id,    sizeof (int) );
     read_buf(fd, &alpha, sizeof (double) );
     alphaList[i].id = id;
     alphaList[i].alpha = alpha;
  }

  // feature index
  fi = new unsigned int [dimensionSize];
  read_buf (fd, fi, sizeof (unsigned int) * dimensionSize);

  // table
  table = new int [tableSize];
  read_buf (fd, table, sizeof (int) * tableSize);

  // Double Array
  doubleArraySize /= sizeof (_Unit);
  doubleArray = new _Unit  [doubleArraySize];
  read_buf (fd, doubleArray, sizeof (_Unit) * doubleArraySize);

  // close
  close(fd);
  
  // check size
  if (readSize != size) {
    throw string("svmModel::readModel(): size of model file seems to be invalid.");
  }

  // initilize
  dotCache = new double [nonzeroDimensionSize+1];
  for (unsigned int i = 0; i <= nonzeroDimensionSize; i++) 
    dotCache[i] = pow (param_s*i  + param_r, param_degree);

  dotBuf      = new unsigned int [svSize];
  _resultList = new double [modelSize];
 
  resultList  = new ModelResult [classSize];
  for (unsigned int i = 0; i < classSize; i++) {
    resultList[i].className = string(classList[i]);
    resultList[i].voteScore = 0;
    resultList[i].distScore = 0.0;
  }

  return 1;
}

// Classify one example, given as featuresSize NUL-terminated feature
// strings.
//
// Steps:
//   1. Look up each feature in the double array. A hit on a leaf yields
//      (through fi and table) a list of indices into dotBuf, which has
//      svSize entries. Each index's count is incremented, so dotBuf[sv]
//      is the number of input features shared with entry sv.
//   2. Each model i scores
//        -b_i + sum(alpha * (param_s * dot + param_r) ^ param_degree)
//      over its (id, alpha) entries, where dot = dotBuf[id].
//   3. A positive score votes for class modelList[i].pos and a negative
//      score for modelList[i].neg. The vote adds |score| to that class's
//      distScore. A score of exactly 0 casts no vote.
//
// Returns the object's internal resultList (classSize entries). The next
// call overwrites it.
ModelResult *svmModel::classify(char **featuresList, unsigned int featuresSize) 
{
  // reset per-call state
  for (unsigned int i = 0; i < svSize; i++) dotBuf[i] = 0;
  for (unsigned int i = 0; i < modelSize; i++) _resultList[i] = -(modelList[i].b);
  for (unsigned int i = 0; i < classSize; i++) {
    resultList[i].voteScore = 0;
    resultList[i].distScore = 0.0;
  }

  // 1. count shared features
  for (unsigned int k = 0; k < featuresSize && doubleArraySize > 0; k++) {
    const unsigned char *key =
      reinterpret_cast<const unsigned char *>(featuresList[k]);
    int b = doubleArray[0].base;
    bool matched = true;

    for (; *key != '\0'; key++) {
      // Unsigned arithmetic gives the same slot as the original int sum,
      // but a corrupt base cannot cause signed overflow.
      const unsigned int p = static_cast<unsigned int>(b) + *key + 1u;
      // Arbitrary feature bytes must not index past the array.
      if (p >= doubleArraySize ||
          static_cast<unsigned int>(b) != doubleArray[p].check) {
        matched = false;
        break;
      }
      b = doubleArray[p].base;
    }
    if (!matched) continue;

    // terminal slot of the matched key; a negative base encodes -(index + 1)
    const unsigned int leaf = static_cast<unsigned int>(b);
    if (leaf >= doubleArraySize) continue;
    const int n = doubleArray[leaf].base;
    if (static_cast<unsigned int>(b) == doubleArray[leaf].check && n < 0)
      for (int j = fi[-n-1]; table[j] != -1; j++) dotBuf[table[j]]++;
  }

  // 2. Accumulate the kernel terms. Each model's (id, alpha) entries are
  //    terminated by an entry with id == -1. Bounding j and m keeps this
  //    loop finite even when modelSize == 0.
  const unsigned int alphaTotal = alphaSize + modelSize;
  unsigned int m = 0;
  for (unsigned int j = 0; j < alphaTotal && m < modelSize; j++) {
    if (alphaList[j].id == -1) {
      m++;
      continue;
    }
    const unsigned int dot = dotBuf[alphaList[j].id];
    // dotCache covers counts 0..nonzeroDimensionSize only. Nothing here
    // bounds dot by that (a feature repeated in featuresList is counted
    // twice), so compute the same kernel directly when it is out of range.
    const double kernel = (dot <= nonzeroDimensionSize)
      ? dotCache[dot]
      : pow (param_s * dot + param_r, param_degree);
    _resultList[m] += alphaList[j].alpha * kernel;
  }

  // 3. each binary model votes for its positive or negative class
  for (unsigned int i = 0; i < modelSize; i++) {
    const double score = _resultList[i];
    if (score > 0) {
      resultList[modelList[i].pos].voteScore++;
      resultList[modelList[i].pos].distScore += fabs (score);
    } else if (score < 0) {
      resultList[modelList[i].neg].voteScore++;
      resultList[modelList[i].neg].distScore += fabs (score);
    }
  }

  return resultList;
}
}
