#include <stdio.h>
#include <stdlib.h>
#include <math.h>
#include <assert.h>
#include "nm.h"

/* Write the three components of *v to stdout as "{x, y, z}" using %f
 * formatting. No trailing newline is emitted. */
void printV3D(VECT3D* v) {
    const VECT3D *vec = v;

    fprintf(stdout, "{%f, %f, %f}", vec->x, vec->y, vec->z);
}

/* Write the two components of *v to stdout as "{x, y}" using %f
 * formatting. No trailing newline is emitted. */
void printV2D(VECT2D* v) {
    const VECT2D *vec = v;

    fprintf(stdout, "{%f, %f}", vec->x, vec->y);
}

/*
 * Minimise a function of two variables with a Nelder-Mead style downhill
 * simplex search (reflect / expand / contract / shrink on a triangle).
 *
 * func      - objective function f(x, y) to minimise.
 * xi, yi    - starting point; the initial simplex is
 *             (xi, yi), (xi + 1, yi), (xi, yi + 1).
 * threshold - the search stops once the sample standard deviation of the
 *             three vertex values drops below this value.
 * callback  - if not NULL, called at the start of every iteration with a
 *             pointer to the first of the three simplex vertices.
 *
 * At most NM_MAX_ITERATIONS iterations are run. The step coefficients
 * NM_COF_REFLECT, NM_COF_EXPAND and NM_COF_CONTRACT are defined outside this
 * file (presumably in nm.h).
 *
 * Returns the smallest function value among the vertices of the final
 * simplex; the location of that vertex is not returned.
 */
double nm(double (*func)(double x, double y),double xi, double yi, double threshold, void (*callback)(VECT3D* v)) {
    VECT3D      v[3];
    VECT3D      vr,ve,vc;
    VECT2D      vm;
    int         vg,vs,vh,i,k;
    double      s,m;

    // Initial simplex: the start point plus a unit step along each axis.
    v[0].x = xi;
    v[0].y = yi;
    v[0].z = func(v[0].x,v[0].y);

    v[1].x = xi + 1;
    v[1].y = yi;
    v[1].z = func(v[1].x,v[1].y);

    v[2].x = xi;
    v[2].y = yi + 1;
    v[2].z = func(v[2].x,v[2].y);

    for (k=0;k<NM_MAX_ITERATIONS;k++) {

        if (callback != NULL)
            callback(v);

        #ifdef NM_DEBUG
        printf("ITERATION #%d\n",k+1);
        printf("v={");
        printV3D(&v[0]);
        printf(",\n\t ");
        printV3D(&v[1]);
        printf(",\n\t ");
        printV3D(&v[2]);
        printf("}\n");
        #endif

        // Rank the vertices: vg = greatest (worst), vs = smallest (best).
        vg = 0;
        vs = 0;
        for (i=1;i<3;i++) {
            if (v[i].z > v[vg].z)
                vg = i;
            if (v[i].z < v[vs].z)
                vs = i;
        }

        // vg == vs only happens when no vertex compares greater or smaller
        // than v[0] (all values equal, or not comparable). Any assignment of
        // distinct indices is valid then; without this, vh below would be 3
        // and index past the end of v[].
        if (vg == vs) {
            vg = 0;
            vs = 1;
        }

        vh = 3-vs-vg; // the remaining (middle) vertex

        #ifdef NM_DEBUG
        printf("vg=%d, vs=%d, vh=%d\n",vg,vs,vh);
        #endif

        assert((vg!=vs) && (vg!=vh) && (vs!=vh));

        // Centroid of the edge opposite the worst vertex.
        vm.x = (v[vs].x + v[vh].x)/2.0;
        vm.y = (v[vs].y + v[vh].y)/2.0;

        // Reflect the worst vertex through the centroid.
        vr.x = vm.x+NM_COF_REFLECT*(vm.x-v[vg].x);
        vr.y = vm.y+NM_COF_REFLECT*(vm.y-v[vg].y);
        vr.z = func(vr.x,vr.y);

        // The three cases below are mutually exclusive because
        // v[vs].z <= v[vh].z and only v[vg] is modified inside them.
        if ((vr.z < v[vh].z) && (vr.z >= v[vs].z)) {
            // Reflection is better than the middle vertex but not a new
            // best: accept it.
            v[vg].x = vr.x;
            v[vg].y = vr.y;
            v[vg].z = vr.z;
        }
        else if (vr.z < v[vs].z) {
            // Reflection is a new best: try going further the same way.
            ve.x = vm.x+NM_COF_EXPAND*(vr.x-vm.x);
            ve.y = vm.y+NM_COF_EXPAND*(vr.y-vm.y);
            ve.z = func(ve.x,ve.y);

            if (ve.z < vr.z) {      // go with the expansion
                v[vg].x = ve.x;
                v[vg].y = ve.y;
                v[vg].z = ve.z;
            }
            else {                  // stick with the reflection
                v[vg].x = vr.x;
                v[vg].y = vr.y;
                v[vg].z = vr.z;
            }
        }
        else if (vr.z >= v[vh].z) {
            // Reflection is no better than the middle vertex: contract.
            if (vr.z < v[vg].z) {
                // Outside contraction (on the reflected side of the centroid).
                vc.x = vm.x+NM_COF_CONTRACT*(vm.x-v[vg].x);
                vc.y = vm.y+NM_COF_CONTRACT*(vm.y-v[vg].y);
            }
            else {
                // Inside contraction (towards the worst vertex).
                vc.x = vm.x-NM_COF_CONTRACT*(vm.x-v[vg].x);
                vc.y = vm.y-NM_COF_CONTRACT*(vm.y-v[vg].y);
            }
            vc.z = func(vc.x,vc.y);

            if (vc.z < v[vg].z) {
                v[vg].x = vc.x;
                v[vg].y = vc.y;
                v[vg].z = vc.z;
            }
            else {
                // Contraction failed: shrink the simplex by halving the
                // distance from the best vertex to the other two.
                v[vg].x = v[vs].x+(v[vg].x-v[vs].x)/2.0;
                v[vg].y = v[vs].y+(v[vg].y-v[vs].y)/2.0;
                v[vg].z = func(v[vg].x,v[vg].y);

                v[vh].x = v[vs].x+(v[vh].x-v[vs].x)/2.0;
                v[vh].y = v[vs].y+(v[vh].y-v[vs].y)/2.0;
                v[vh].z = func(v[vh].x,v[vh].y);
            }
        }

        // Converged? Compare the sample standard deviation (n-1 = 2) of the
        // three vertex values against the caller's threshold.
        m = (v[vg].z + v[vh].z + v[vs].z)/3.0;
        s = sqrt(pow(v[vg].z-m,2.0)/2.0 + pow(v[vh].z-m,2.0)/2.0 + pow(v[vs].z-m,2.0)/2.0);
        if (s < threshold)
            break;

        #ifdef NM_DEBUG
        printf("-------------------------------------------------------------------------------\n\n");
        #endif
    }

    // The last step may have changed the ranking, so search again for the
    // vertex with the smallest value.
    vs = 0;
    for (i=1;i<3;i++) {
        if (v[i].z < v[vs].z)
            vs = i;
    }
    return v[vs].z;
}
