// Copyright (C) 2002 Samy Bengio (bengio@idiap.ch)
//                
//
// This file is part of Torch. Release II.
// [The Ultimate Machine Learning Library]
//
// Torch is free software; you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation; either version 2 of the License, or
// (at your option) any later version.
//
// Torch is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Torch; if not, write to the Free Software
// Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA

#include "Multinomial.h"
#include "log_add.h"
#include "random.h"

namespace Torch {

/* Builds a discrete distribution over n_values_ values (0 .. n_values_-1).
 * prior_weights_ is the pseudo-count added to every value at the start of
 * each EM iteration (see eMIterInitialize). Buffers are not allocated here;
 * that happens in allocateMemory(). */
Multinomial::Multinomial(int n_values_,real prior_weights_) : Distribution()
{
  n_observations = 1;
  n_inputs = 0;
  n_values = n_values_;
  prior_weights = prior_weights_;
  initial_params = NULL;
  initial_file = NULL;
  // The destructor calls freeMemory(), which free()s these buffers. Start them
  // as NULL so that destroying an object whose allocateMemory() never ran does
  // not free an indeterminate pointer (free(NULL) is a no-op).
  // NOTE(review): assumes the params/der_params/outputs lists are NULL-initialized
  // by the base class so freeList() on them is also safe — confirm.
  log_weights = NULL;
  dlog_weights = NULL;
  log_probabilities = NULL;
  weights_acc = NULL;
  addOption("initial params",sizeof(List*),&initial_params,"initial params");
  addOption("initial file",sizeof(char*),&initial_file,"initial file");
}

void Multinomial::allocateMemory()
{
  max_n_frames = 1;
  n_params = numberOfParams();
  addToList(&params,n_params,(real*)xalloc(sizeof(real)*n_params));
  addToList(&der_params,n_params,(real*)xalloc(sizeof(real)*n_params));
  addToList(&outputs,n_outputs,(real*)xalloc(sizeof(real)*n_outputs));
  log_weights = (real*)params->ptr;
  dlog_weights = (real*)der_params->ptr;
  log_probabilities = (real*)xalloc(sizeof(real)*max_n_frames);
  weights_acc = (real*)xalloc(sizeof(real)*n_values);
}

void Multinomial::freeMemory()
{
  freeList(&outputs,true);
  freeList(&params,true);
  freeList(&der_params,true);
  free(log_probabilities);
  free(weights_acc);
}

/* The model has exactly one (log) weight per discrete value. */
int Multinomial::numberOfParams()
{
  const int n_weights = n_values;
  return n_weights;
}

void Multinomial::reset()
{
  // here, initialize the parameters somehow...

  if (initial_params) {
    copyList(params,initial_params);
  } else if (initial_file) {
    load(initial_file);
  } else {
    // initialize randomly the weights
    real sum = 0.;
    for (int i=0;i<n_values;i++) {
      log_weights[i] = bounded_uniform(0.1,1);
      sum += log_weights[i];
    }
    for (int i=0;i<n_values;i++) {
      log_weights[i] = log(log_weights[i]/sum);
    }
  }
}

/* Ensures the per-frame log-probability cache can hold every real frame of the
 * sequence in inputs->ptr. The cache only ever grows. A NULL list is ignored. */
void Multinomial::eMSequenceInitialize(List* inputs)
{
  if (inputs) {
    SeqExample* seq = (SeqExample*)inputs->ptr;
    if (seq->n_real_frames > max_n_frames) {
      max_n_frames = seq->n_real_frames;
      log_probabilities = (real*)xrealloc(log_probabilities,sizeof(real)*max_n_frames);
    }
  }
}

/* Prepares for gradient computation on a new sequence: grows the frame cache
 * if needed, then clears the derivative accumulators dlog_weights. */
void Multinomial::sequenceInitialize(List* inputs)
{
  eMSequenceInitialize(inputs);
  for (int k=0;k<n_values;k++)
    dlog_weights[k] = 0;
}

/* Log-probability of frame t. observations[0] holds the discrete value, which
 * is used directly as an index into log_weights. The result is also cached in
 * log_probabilities[t] for the backward / EM passes.
 * NOTE(review): no range check on the value — it must lie in [0, n_values). */
real Multinomial::frameLogProbability(real *observations, real *inputs, int t)
{
  const int value = (int)observations[0];
  log_probabilities[t] = log_weights[value];
  return log_probabilities[t];
}

/* E-step accumulation for frame t: adds the frame's posterior mass to the
 * count of the value actually observed (observations[0]).
 * Only the observed value is credited. Crediting every value i with
 * post * w_i / w_obs makes the accumulated counts proportional to the current
 * weights, so the normalization in eMUpdate() would reproduce them unchanged
 * and EM would never learn anything. */
void Multinomial::frameEMAccPosteriors(real *observations, real log_posterior, real *inputs, int t)
{
  int obs = (int)observations[0];
  weights_acc[obs] += exp(log_posterior);
}

/* M-step: the new weights are the accumulated counts normalized to sum to one,
 * stored in the log domain. A value with a zero count gets log(0). */
void Multinomial::eMUpdate()
{
  real sum_weights_acc = 0;
  for (int i=0;i<n_values;i++)
    sum_weights_acc += weights_acc[i];
  // With prior_weights == 0 and no accumulated posterior mass the total is 0.
  // log(0) - log(0) would then turn every weight into NaN, so keep the
  // current parameters instead.
  if (sum_weights_acc <= 0)
    return;
  real log_sum = log(sum_weights_acc);
  for (int i=0;i<n_values;i++)
    log_weights[i] = log(weights_acc[i]) - log_sum;
}

void Multinomial::eMIterInitialize()
{
  // initialize the accumulators to 0 and compute pre-computed value
  for (int i=0;i<n_values;i++) {
    weights_acc[i] = prior_weights;
  }
}

/* Per-iteration hook for gradient training. There is nothing to prepare here;
 * derivative accumulators are cleared per sequence in sequenceInitialize(). */
void Multinomial::iterInitialize()
{
  // intentionally empty
}

/* Accumulates into dlog_weights the gradient contribution of frame t, scaled
 * by *alpha. The frame's log-probability is log_weights[obs] (see
 * frameLogProbability), so only the observed value contributes. Treating
 * w_j = exp(log_weights[j]) as a normalized (softmax-like) distribution gives
 *   dlog_weights[j] += -alpha * (delta(j,obs) - w_j).
 * Summing this term over every value i, as if each had been observed, makes
 * the two updates cancel exactly for normalized weights, so the gradient was
 * always zero. It also cost O(n_values^2) per frame. */
void Multinomial::frameBackward(real *observations, real *alpha, real *inputs, int t)
{
  int obs = (int)observations[0];
  // log_probabilities[t] was cached from log_weights[obs], so this factor is
  // normally exp(0) == 1.
  real post = - *alpha * exp(log_weights[obs] - log_probabilities[t]);
  dlog_weights[obs] += post;
  for (int j=0;j<n_values;j++)
    dlog_weights[j] -= post * exp(log_weights[j]);
}

/* Writes the expected observed value, sum_i i * P(i) with
 * P(i) = exp(log_weights[i]), into observations[0]. Values are the indices
 * 0 .. n_values-1, as used by frameLogProbability.
 * The previous form, sum_i P(i) / n_values, equals 1/n_values for any
 * normalized distribution and carried no information about the weights. */
void Multinomial::frameExpectation(real *observations, real *inputs, int t)
{
  real expectation = 0;
  for (int i=0;i<n_values;i++)
    expectation += (real)i * exp(log_weights[i]);
  observations[0] = expectation;
}


/* Destructor: delegates all cleanup to freeMemory(), which releases every
 * buffer obtained in allocateMemory(). */
Multinomial::~Multinomial()
{
  this->freeMemory();
}

}

