/*
    This file is part of nncore.
    
    This code is written by Davide Albanese, <albanese@fbk.it>.
    (C) 2008 Fondazione Bruno Kessler - Via Santa Croce 77, 38100 Trento, ITALY.

    This program is free software: you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation, either version 3 of the License, or
    (at your option) any later version.

    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with this program.  If not, see <http://www.gnu.org/licenses/>.
*/


#include <Python.h>
#include <numpy/arrayobject.h>
#include <math.h>
#include <stdio.h>
#include <stdlib.h>
#include "numpysupport.h"
#include "nn.h"

  
/* Predict NN
 *
 * Python-visible as nncore.predictnn(x, y, sample, classes, k, dist).
 *
 * Converts the inputs to C-contiguous NumPy arrays (x and sample as
 * double; y and classes as long, then copied into int buffers), fills a
 * NearestNeighbor struct and returns the int produced by predict_nn()
 * as a Python int.
 *
 * Shapes validated here:
 *   x       : 2-D, n x d
 *   y       : 1-D, length n
 *   sample  : 1-D, length d
 *   classes : 1-D, length nclasses
 * k and dist are forwarded to predict_nn() unchanged.
 *
 * On error a Python exception is set and NULL is returned. Every
 * converted array and temporary buffer is released on all exit paths.
 */
static PyObject *nncore_predictnn(PyObject *self, PyObject *args, PyObject *keywds)
{
  PyObject *x = NULL;       PyObject *xc = NULL;
  PyObject *y = NULL;       PyObject *yc = NULL;
  PyObject *sample = NULL;  PyObject *samplec = NULL;
  PyObject *classes = NULL; PyObject *classesc = NULL;

  PyObject *result = NULL;   /* stays NULL on any error path */
  double **_x = NULL;
  int *_y = NULL;
  int *_classes = NULL;
  double *margin = NULL;     /* set by predict_nn(); NULL keeps free() safe */

  long *_ytmp, *_classestmp;
  double *_sample;
  int n, d, nclasses, pred;
  int k, dist;
  int i;
  NearestNeighbor nn;

  /* Parse Tuple */
  static char *kwlist[] = {"x", "y", "sample", "classes", "k", "dist", NULL};

  if (!PyArg_ParseTupleAndKeywords(args, keywds, "OOOOii", kwlist,
                                   &x, &y, &sample, &classes, &k, &dist))
    return NULL;

  /* Convert inputs. On failure NumPy has already set an exception; jump
     to cleanup so that previously converted arrays are released. */
  xc = PyArray_FROM_OTF(x, NPY_DOUBLE, NPY_IN_ARRAY);
  if (xc == NULL) goto cleanup;

  yc = PyArray_FROM_OTF(y, NPY_LONG, NPY_IN_ARRAY);
  if (yc == NULL) goto cleanup;

  samplec = PyArray_FROM_OTF(sample, NPY_DOUBLE, NPY_IN_ARRAY);
  if (samplec == NULL) goto cleanup;

  classesc = PyArray_FROM_OTF(classes, NPY_LONG, NPY_IN_ARRAY);
  if (classesc == NULL) goto cleanup;

  /* Check rank before any PyArray_DIM call. Asking for a dimension the
     array does not have reads past the end of its shape. */
  if (PyArray_NDIM(xc) != 2){
    PyErr_SetString(PyExc_ValueError, "x array must be 2-dimensional");
    goto cleanup;
  }
  if (PyArray_NDIM(yc) != 1){
    PyErr_SetString(PyExc_ValueError, "y array must be 1-dimensional");
    goto cleanup;
  }
  if (PyArray_NDIM(samplec) != 1){
    PyErr_SetString(PyExc_ValueError, "sample array must be 1-dimensional");
    goto cleanup;
  }
  if (PyArray_NDIM(classesc) != 1){
    PyErr_SetString(PyExc_ValueError, "classes array must be 1-dimensional");
    goto cleanup;
  }

  /* Check size */
  if (PyArray_DIM(yc, 0) != PyArray_DIM(xc, 0)){
    PyErr_SetString(PyExc_ValueError, "y array has wrong 0-dimension");
    goto cleanup;
  }

  if (PyArray_DIM(samplec, 0) != PyArray_DIM(xc, 1)){
    PyErr_SetString(PyExc_ValueError, "sample array has wrong 0-dimension");
    goto cleanup;
  }

  n           = (int) PyArray_DIM(xc, 0);
  d           = (int) PyArray_DIM(xc, 1);
  nclasses    = (int) PyArray_DIM(classesc, 0);
  _ytmp       = (long *) PyArray_DATA(yc);
  _sample     = (double *) PyArray_DATA(samplec);
  _classestmp = (long *) PyArray_DATA(classesc);

  /* NOTE(review): assumes dmatrix_from_numpy returns a block released
     with a single free() (as the original code did) and NULL on
     allocation failure. Confirm in numpysupport. */
  _x       = dmatrix_from_numpy(xc);
  _y       = malloc(n * sizeof *_y);
  _classes = malloc(nclasses * sizeof *_classes);

  /* malloc(0) may legitimately return NULL, so treat NULL as a failure
     only when a non-empty buffer was requested. */
  if ((_x == NULL && n > 0) || (_y == NULL && n > 0) ||
      (_classes == NULL && nclasses > 0)){
    if (!PyErr_Occurred())
      PyErr_NoMemory();
    goto cleanup;
  }

  /* predict_nn() works on int labels; narrow the long arrays NumPy gave us. */
  for (i = 0; i < n; i++)
    _y[i] = (int) _ytmp[i];

  for (i = 0; i < nclasses; i++)
    _classes[i] = (int) _classestmp[i];

  nn.n = n;
  nn.d = d;
  nn.x = _x;
  nn.y = _y;
  nn.classes = _classes;
  nn.nclasses = nclasses;
  nn.k = k;
  nn.dist = dist;

  pred = predict_nn(&nn, _sample, &margin);
  result = Py_BuildValue("i", pred);

 cleanup:
  /* free(NULL) is a no-op and Py_XDECREF tolerates NULL, so this single
     exit path is safe whichever step failed. */
  free(_x);
  free(_y);
  free(_classes);
  free(margin);

  Py_XDECREF(xc);
  Py_XDECREF(yc);
  Py_XDECREF(samplec);
  Py_XDECREF(classesc);

  return result;
}


/* Docstring attached to nncore.predictnn through the method table. */
static char nncore_predictnn_doc[] =
  "Predict NN. Return the prediction (an integer)";

/* Docstring for the nncore module, passed to Py_InitModule3. */
static char module_doc[] =
  "NN module based on NN C-libraries developed by Stefano Merler";

/* Method table */
static PyMethodDef nncore_methods[] = {
  {"predictnn",
   (PyCFunction)nncore_predictnn,
   METH_VARARGS | METH_KEYWORDS,
   nncore_predictnn_doc},
  {NULL, NULL, 0, NULL}
};


/* Init */
void initnncore()
{
  Py_InitModule3("nncore", nncore_methods, module_doc);
  import_array();
}
