/*
************************************************************************
*
*   ConjGrad.c - minimization by conjugate gradients
*
*   Copyright (c) 1996
*
*   ETH Zuerich
*   Institut fuer Molekularbiologie und Biophysik
*   ETH-Hoenggerberg
*   CH-8093 Zuerich
*
*   SPECTROSPIN AG
*   Industriestr. 26
*   CH-8117 Faellanden
*
*   All Rights Reserved
*
*   Date of last modification : 96/05/21
*   Pathname of SCCS file     : /sgiext/molmol/tools/src/SCCS/s.ConjGrad.c
*   SCCS identification       : 1.3
*
************************************************************************
*/

#include <conj_grad.h>

#include <stdio.h>
#include <stdlib.h>
#include <math.h>

#include <brent.h>

/* Conjugate gradient minimization (Fletcher-Reeves-Polak-Ribiere method),
   adapted from the algorithm given in Numerical Recipes. */

/* Tolerance handed to BrentSolveDerMin for each line minimization. */
#define TOL 2.0e-4
/* Maximum number of conjugate gradient iterations in ConjGradMin. */
#define ITMAX 200
/* Absolute floor in the relative convergence test of ConjGradMin, so that
   convergence can still be detected when the function value approaches
   zero. */
#define EPS 1.0e-10

/* State shared between ConjGradMin and the line-minimization helpers.
   The helpers are called back by the Brent routines with a NULL client
   data pointer, so they read everything they need from these variables.
   As a consequence the minimizer is not reentrant. */
static ConjGradMinFunc MinFuncCom;  /* user function/gradient callback */
static float *PCom;                 /* current point, moved by linmin */
static float *XiCom;                /* current search direction */
static int NCom;                    /* number of variables */
static void *ClientDataCom;         /* user data passed to MinFuncCom */
static float *Xt;                   /* scratch: trial point on the line */
static float *Df;                   /* scratch: gradient at the trial point */

static float
f1dim(float x, void *clientData)
{
  float f;
  int j;

  for (j = 0; j < NCom; j++)
    Xt[j] = PCom[j] + x * XiCom[j];

  (void) MinFuncCom(Xt, NCom, ClientDataCom, &f, Df);

  return f;
}

/* Value and directional derivative of the user function along the current
   search line at parameter x.  Evaluates the callback at
   PCom + x * XiCom, stores the function value in *fP and the projection
   of the gradient (collected in the scratch array Df) onto XiCom in *dP.
   'clientData' is unused; ClientDataCom is passed to the callback. */
static void
df1dim(float x, void *clientData, float *fP, float *dP)
{
  float slope = 0.0;
  int k = 0;

  while (k < NCom) {
    Xt[k] = PCom[k] + x * XiCom[k];
    k++;
  }

  (void) MinFuncCom(Xt, NCom, ClientDataCom, fP, Df);

  /* d/dx f(PCom + x * XiCom) = grad f . XiCom */
  for (k = 0; k < NCom; k++)
    slope = slope + Df[k] * XiCom[k];

  *dP = slope;
}


/* Minimize the user function along the line through PCom in direction
   XiCom.  The minimum is bracketed with BrentBracketMin on f1dim (seeded
   with the abscissae 0, 1 and 2) and then located with BrentSolveDerMin
   on df1dim using tolerance TOL.  Afterwards XiCom is scaled to the step
   actually taken and PCom is moved by that step.  Returns the function
   value reported by BrentSolveDerMin (presumably the value at the line
   minimum). */
static float
linmin(void)
{
  float lo = 0.0, mid = 1.0, hi = 2.0;
  float fLo, fMid, fHi;
  float step, fMin;
  int k;

  BrentBracketMin(f1dim, NULL, &lo, &mid, &hi, &fLo, &fMid, &fHi);
  (void) BrentSolveDerMin(df1dim, NULL, lo, mid, hi, TOL, &step, &fMin);

  for (k = 0; k < NCom; k++) {
    XiCom[k] = step * XiCom[k];
    PCom[k] = PCom[k] + XiCom[k];
  }

  return fMin;
}

/* Minimize a function of 'n' variables with the Polak-Ribiere variant of
   the conjugate gradient method (adapted from the Numerical Recipes
   routine frprmn).

   func       - callback, called as func(point, n, clientData, &value, grad).
                It must store the function value at 'point' through its
                fourth argument and the gradient (n floats) into its fifth
                argument.  A non-zero (TRUE) return value stops the
                minimization immediately.
   p          - in: starting point; out: last point reached (n floats,
                updated in place).
   n          - number of variables.
   clientData - passed unchanged to every call of 'func'.
   ftol       - fractional tolerance on the function value, used as the
                convergence criterion.

   Returns TRUE if the relative change of the function value over one line
   minimization fell below 'ftol', if the gradient became exactly zero, or
   if 'func' requested a stop.  Returns FALSE if ITMAX iterations were used
   up without convergence, or if a work array could not be allocated (in
   that case 'func' may not have been called at all).

   Not reentrant: the line minimization reads its state from file-scope
   static variables, which are set up here and cleared before returning. */
BOOL
ConjGradMin(ConjGradMinFunc func, float p[], int n, void *clientData,
    float ftol)
{
  /* malloc(0) may return NULL; always request at least one element so
     that a NULL result unambiguously means allocation failure. */
  size_t len = (n > 0) ? (size_t) n : 1;
  BOOL result = FALSE;
  int j, its;
  float fret, gg, gam, fp, dgg;
  float *g = NULL, *h = NULL, *xi;

  /* Ensure the cleanup path never frees stale pointers. */
  Xt = NULL;
  Df = NULL;

  xi = malloc(len * sizeof(*xi));
  if (xi == NULL)
    goto cleanup;

  /* Function value and gradient at the starting point. */
  if (func(p, n, clientData, &fp, xi)) {
    result = TRUE;
    goto cleanup;
  }

  g = malloc(len * sizeof(*g));
  h = malloc(len * sizeof(*h));
  Xt = malloc(len * sizeof(*Xt));
  Df = malloc(len * sizeof(*Df));
  if (g == NULL || h == NULL || Xt == NULL || Df == NULL)
    goto cleanup;

  /* Publish the problem to the line-minimization callbacks. */
  MinFuncCom = func;
  PCom = p;
  NCom = n;
  XiCom = xi;
  ClientDataCom = clientData;

  /* The first search direction is the steepest descent direction. */
  for (j = 0; j < n; j++) {
    g[j] = - xi[j];
    h[j] = g[j];
    xi[j] = g[j];
  }

  for (its = 0; its < ITMAX; its++) {
    /* Minimize along xi; this moves p to the line minimum and scales xi
       to the step actually taken. */
    fret = linmin();

    if (2.0 * fabs(fret - fp) <= ftol * (fabs(fret) + fabs(fp) + EPS)) {
      result = TRUE;
      goto cleanup;
    }

    /* Function value and gradient at the updated point. */
    if (func(p, n, clientData, &fp, xi)) {
      result = TRUE;
      goto cleanup;
    }

    /* Polak-Ribiere factor:
       gam = (grad_new - grad_old) . grad_new / |grad_old|^2,
       where g still holds the negated previous gradient. */
    dgg = gg = 0.0;
    for (j = 0; j < n; j++) {
      gg += g[j] * g[j];
      dgg += (xi[j] + g[j]) * xi[j];
    }
    if (gg == 0.0) {
      /* Previous gradient was exactly zero: nothing left to minimize. */
      result = TRUE;
      goto cleanup;
    }

    gam = dgg / gg;
    for (j = 0; j < n; j++) {
      g[j] = - xi[j];
      h[j] = g[j] + gam * h[j];
      xi[j] = h[j];
    }
  }

cleanup:
  free(g);
  free(h);
  free(xi);
  free(Xt);
  free(Df);

  /* Don't leave dangling pointers to freed or caller-owned memory in the
     file-scope state. */
  Xt = NULL;
  Df = NULL;
  XiCom = NULL;
  PCom = NULL;

  return result;
}
