// laguerreTransform.cc
#include <cmath>
#include <complex>   // std::complex and its operator<< (DEBUG >= 2 dumps)
#include <iostream>
#include <algorithm> // for copy
#include <iterator>  // for ostream_iterator
#include <vector>

#ifdef CBLAS
#ifdef Darwin
#include <Accelerate/Accelerate.h>
#elif Linux
#include <cblas.h>
#endif
#endif

// Kept ahead of the project headers, as in the original ordering.
// NOTE(review): presumably those headers rely on unqualified std names — confirm.
using namespace std;

#include "laguerreTransform.h"
#include "lagsht_exceptions.h"

#define DEBUG 1

namespace LagSHT {


// Constructor.
//   N     : forwarded to BaseLaguerreTransform (N_ is used below — presumably set there).
//   R     : stored in R_; R2Index uses it to rescale radii onto the node axis.
//   alpha : forwarded to BaseLaguerreTransform (alpha_ used below); must be >= 0.
// Throws LagSHTError when alpha < 0.
// Fills nodes_/weights_ from a LaguerreFuncQuad(N_, alpha_) and stores
// alphaFact_ = sqrt(alpha!), which the Multi* transforms use for normalisation.
LaguerreTransform::LaguerreTransform(int N, r_8 R, int alpha) : 
  BaseLaguerreTransform(N,alpha), R_(R) {

#if DEBUG >= 1
  cout << "LaguerreTransform start...." <<endl;
#endif

  if(alpha<0)  throw LagSHTError("LagTransform call with alpha<0");

  // Quadrature nodes and weights.
  LaguerreFuncQuad lag(N_,alpha_);
  lag.QuadWeightNodes(nodes_,weights_);

  // alpha! accumulated in floating point. An int product overflows
  // (undefined behaviour) as soon as alpha > 12. For alpha <= 12 the
  // result is bit-identical, because those factorials are exact in a double.
  r_8 alphaFact = 1.0;
  for(int i=2; i<=alpha_; i++) alphaFact *= static_cast<r_8>(i);

  alphaFact_ = sqrt(alphaFact); //sqrt(alpha!)

}//Ctor

44
int LaguerreTransform::R2Index(r_8 r) const {
45 46 47 48 49 50 51
  using namespace std;
  r_8 rscaled = r*nodes_[N_-1]/R_;
  vector<r_8> dist(N_);
  transform(nodes_.begin(),nodes_.end(),dist.begin(), distance(rscaled));
  return min_element(dist.begin(),dist.end()) - dist.begin();
}//RadIndex

52 53 54 55

void LaguerreTransform::MultiAnalysis(const vector< complex<r_8> >& fi, 
				      vector< complex<r_8> >& fn, int stride) {

56
  if(fi.size() != (size_t)(N_*stride) ) throw LagSHTError("LagTransform::Analysis size error");
57 58 59 60 61 62 63 64 65 66 67 68 69 70 71
    fn.resize(N_*stride);

#if DEBUG >= 1
    cout << "Analysis with Function...." << endl;
#endif
#if DEBUG >= 2
    std::copy(fi.begin(), fi.end(), std::ostream_iterator<char>(std::cout, " "));
#endif                     
    r_8 invalphaFact = 1.0/alphaFact_; //1/alpha!                   

    vector<r_8> facts(N_); //sqrt(n!/(n+alpha)!) en iteratif
    facts[0] = invalphaFact;
    for(int n=1;n<N_;n++) facts[n] = facts[n-1]*sqrt(((r_8)n)/((r_8)(n+alpha_)) );
    
    LaguerreFuncQuad lag(N_-1);
72
    vector<r_8> LnAll(N_); //all the values Lag_n(r) n:0...N-1
73 74 75 76 77 78

    vector<r_8> LnkMtx(N_*N_); //all the values Lag_n(r) n:0...N-1 TRY r_16->r_8
    for (int k=0; k<N_; k++){
      r_8 rk = nodes_[k];
      lag.Values(rk,LnAll);
      for (int n = 0; n<N_; n++ ){
79
	//	LnkMtx[n*N_+k] = LnAll[n]*facts[n];
80
	LnkMtx[n+N_*k] = LnAll[n]*facts[n]*weights_[k];
81 82 83 84 85
      }
    }



86 87 88
#ifdef CBLAS
    cblas_dgemm (CblasColMajor, CblasNoTrans, CblasTrans, 2*stride, N_, N_, 1., (double*)(&fi[0]), 2*stride, &LnkMtx[0], N_, 0, (double *)(&fn[0]), 2*stride);
#else
89 90 91 92 93 94 95
    vector<complex<r_8> > vtmp(N_);
    for (int l=0; l<stride; l++) {
      vtmp.assign(N_,0.);
      for (int i = 0; i<N_; i++ ){
	complex<r_8> fli = fi[l+ i*stride];
	r_8 wi = weights_[i];
	for (int n=0; n<N_; n++){
96
	  //	  vtmp[n] +=  fli * wi * LnkMtx[n*N_+i];
97
	  vtmp[n] +=  fli * LnkMtx[n+N_*i];
98 99 100 101 102 103
	}//loop on k
      }//loop on n
      for (int n=0; n<N_; n++) {
	fn[l+n*stride] += vtmp[n];
      }
    }//loop on l
104
#endif
105 106 107 108 109 110 111

}//MultiAnalysis


void LaguerreTransform::MultiSynthesis(const vector< complex<r_8> >& fn, 
				       vector< complex<r_8> >& fi, int stride) {

112
  if(fn.size() != (size_t)(N_*stride) ) throw LagSHTError("LagTransform::Synthesis size error");
113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128
    fi.resize(N_*stride);

#if DEBUG >= 1
    cout << "Multi Synthesis with Function....:" << N_ << endl;
#endif
#if DEBUG >= 2
    std::copy(fn.begin(), fn.end(), std::ostream_iterator<char>(std::cout, " "));
#endif
    r_8 invalphaFact = 1.0/alphaFact_; //1/alpha! 

    vector<r_8> facts(N_); //sqrt(n!/(n+alpha)!) en iteratif
    facts[0] = invalphaFact;
    for(int n=1; n<N_; n++) facts[n] = facts[n-1]*sqrt(((r_8)n)/((r_8)(n+alpha_)) );
    

    LaguerreFuncQuad lag(N_-1);
129
    vector<r_8> LnAll(N_); //all the values Lag_n(r) n:0...N-1
130 131 132 133 134 135 136 137 138
    vector<r_8> LnkMtx(N_*N_); //all the values Lag_n(r) n:0...N-1 TRY r_16->r_8
    for (int k=0; k<N_; k++){
      r_8 rk = nodes_[k];
      lag.Values(rk,LnAll);
      for (int n = 0; n<N_; n++ ){
	LnkMtx[n*N_+k] = LnAll[n]*facts[n];
      }
    }
    
139 140 141
#if CBLAS
    cblas_dgemm (CblasColMajor, CblasNoTrans, CblasTrans, 2*stride, N_, N_, 1., (double*)(&fn[0]), 2*stride, &LnkMtx[0], N_, 0, (double *)(&fi[0]), 2*stride);
#else
142 143 144 145 146 147 148 149 150 151 152
    vector<complex<r_8> > vtmp(N_);
    for (int l=0; l<stride; l++) {
      vtmp.assign(N_,0.);
      for (int n = 0; n<N_; n++ ){
	complex<r_8> fln = fn[l+ n*stride];
	for (int k=0; k<N_; k++){
	  vtmp[k] +=  fln * LnkMtx[n*N_+k];
	}//loop on k
      }//loop on n
      for (int k=0; k<N_; k++) fi[l+k*stride] += vtmp[k];
    }//loop on l
153
#endif
154 155 156 157 158
}//MultiSynthesis



}//end namespace