learn_2s_som.m 17.5 KB
Newer Older
1
function [StsMap sMap_denorm Resultout sMapPTout] = learn_2s_som(A,nb_neurone,varargin)
2 3
% Cree la carte SOM ou S2-SOM Pour donnees cachees
%
4 5 6 7 8 9
% Usage:
%
%    [sMap, sMap_denorm, Result, sMapPT] = learn_2s_som(A, nb_neurone, <OPTIOS>)
%
%    St = learn_2s_som(A,nb_neurone, '-struct', <OPTIOS>)
%
10 11 12 13 14
% En entree obligatoire
%
%   A: les donnees cachees
%   nb_neurone: Nombre de neurones 
%
15 16 17 18
% En option, elles sont specifies par couples de valeurs, p. exemple:
%
%         'radius', [ 5 1 ], 'trainlen', 20, 'tracking', 1, ...
%  
19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37
%
%   radius: en forme de vecteur, chaque deux elements qui ce suivent constitue
%           une temperature [i..i+1],[i+1..i+2],....
%   trainlen: en forme de vecteur: chaque element constitue une itération de
%           l'entraienement. NB:vecteur radius doit avoir un element en plus
%           que le vecteur trainlen.
%   tracking: pour visualiser l'apprentissage.
%
%   'S2-SOM': pour faire l'apprentissage avec S2-SOM. Si 'S2-SOM' est
%           specifie alors il faut d'autres parametres:
%
%   DimData: vecteur contenant la dimention de chaque bloc.
%   lambda: vecteur, c'est un hyperparametre pour calculer le poids sur les
%           blocs.
%   eta: vecteur, c'est un hyperparametre pour calculer le poids sur les
%           variables.
%
% En sortie
%
38
%   sMap: La carte SOM ou S2-SOM au point de meilleur "Perf".
39 40 41
%
%   sMap_denorm: La carte SOM ou S2-SOM, denormalisee.
%
42 43
%   iBest = indice dans Result de la carte sMap retournee.
%
44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60
%   Result: structure (vecteur) avec les sorties ou resultats de chaque cas
%           entraine (avec une paire distincte de la combinaison entre lambda
%           et eta): sMap, bmus, Alpha, Beta, Perf.
%
%           Champs de Result:
%              sMap: La carte SOM ou S2-SOM du cas.
%              bmus: Bmus (best matching units) sur toute la zone.
%              Alpha: Coefficients Alpha multipliant les groupes.
%              Beta: Coefficients Alpha multipliant les variables au sans de la
%                  Carte Topologique.
%              Perf: parametre "distortion measure for the map", calcule par
%                  la fonction som_distortion.
%
%   (bmus_pixel:(best matching units) par pixel.)
% Detailed explanation goes here
  
  
61 62 63 64
% Valeurs par defaut
  tracking       = 0;
  init           = 'lininit';
  lattice        = 'rect';
65 66
    
  % flags et variables associees
67 68 69 70 71 72 73 74
  bool_verbose        = false;
  bool_return_struct  = false;
  bool_norm           = false; type_norm     = 'simple';
  bool_rad            = false; rad           = [5 1];
  bool_trainlen       = false; trlen         = 20;
  bool_rad_2s_som     = false; rad_2s_som    = [];
  bool_trlen_2s_som   = false; trlen_2s_som  = [];
  bool_2ssom          = false;
75
  bool_DimData        = false; DimData       = [size(A,2)];
76 77
  bool_lambda         = false; lambda        = 1;
  bool_eta            = false; eta           = 1000;
78
  bool_parcomp        = false; parcomp_workers = 8; % 8 workers for parallel computing by default, if activated (if bool_parcomp is true)
79

80
  Result              = struct([]);
81
  
82
  bool_init_with_make = true;
83
  bool_pre_training   = true;
84 85 86 87 88
  
  %recuperer les donnees
  data.data=A;
  label=[1:size(data.data,2)];
  
89
  %Labelise les donnees (affectation apres boucle d'arguments (selon la valeur de DimBloc)
90 91 92 93
  ListVar={};
  
  data_casename='simulation';
  
94 95 96 97
  % --- CM pour ajouter les arguments 'data_name' et 'comp_names'
  i=1;
  while (i<=length(varargin))
    if ischar(varargin{i})
98
      switch lower(varargin{i}),
99
        case { 'verbose', '-verbose' },
100
          bool_verbose = true;
101 102
        case { 'returnstruct', 'return-struct', 'struct', '-return-struct', '-struct' },
          bool_return_struct = true;
103
        case { 'data_name' },
104
          data_casename = varargin{i+1}; i=i+1;
105
        case { 'comp_names' },
106
          ListVar = varargin{i+1}; i=i+1;
107
        case { 'norm' },
108
          bool_norm = true; 
109 110 111 112 113 114 115 116
          type_norm = varargin{i+1}; i=i+1;
        case { 'init' },
          init = varargin{i+1}; i=i+1;
        case { 'tracking' },
          tracking = varargin{i+1}; i=i+1;
        case { 'lattice' },
          lattice = varargin{i+1}; i=i+1;
        case 'radius'
117
          bool_rad = true;
118 119
          rad = varargin{i+1}; i=i+1;
        case 'trainlen' 
120
          bool_trainlen = true;
121
          trlen = varargin{i+1}; i=i+1;
122 123 124 125 126 127
        case 'radius-2s-som'
          bool_rad_2s_som = true;
          rad_2s_som = varargin{i+1}; i=i+1;
        case 'trainlen-2s-som' 
          bool_trlen_2s_som = true;
          trlen_2s_som = varargin{i+1}; i=i+1;
128
        case {'s2-som', '2s-som'},
129
          disp('** S2-SOM Active **');
130
          bool_2ssom = true;
131
        case {'no-s2-som', 'no-2s-som'},
132 133 134
          disp('** S2-SOM Inactive, only SOM training **');
          bool_2ssom = false;
        case 'dimdata'
135 136 137 138
          DimData = varargin{i+1}; i=i+1;
          for di=1:length(DimData)
            DimBloc(di).Dim = DimData(di);
          end
139
          bool_DimData = true;
140 141 142 143 144
        case 'lambda' 
          lambda=varargin{i+1}; i=i+1;
          if length(lambda) < 1
            error('lambda est de longueur nulle !  Il doit y avoir au moins une valeur')
          end
145 146 147 148 149
          if isscalar(lambda) && lambda <= 0
            bool_lambda = false;
          else
            bool_lambda = true;
          end
150 151 152 153 154
        case 'eta' 
          eta = varargin{i+1}; i=i+1;
          if length(eta) < 1
            error('eta est de longueur nulle !  Il doit y avoir au moins une valeur')
          end
155 156 157 158 159
          if isscalar(eta) && eta <= 0
            bool_eta = false;
          else
            bool_eta = true;
          end
160 161 162
        case 'parcomp'
          bool_parcomp = true;
          parcomp_workers = varargin{i+1}; i=i+1;
163 164 165 166
        case 'ini-with-make'
          bool_init_with_make = true;
        case 'no-ini-with-make'
          bool_init_with_make = false;
167 168 169 170
        case 'pre-training'
          bool_pre_training   = true;
        case 'no-pre-training'
          bool_pre_training   = false;
171 172 173
        otherwise
          error(sprintf(' *** %s error: argument(%d) ''%s'' inconnu ***\n', ...
                        mfilename, i, varargin{i}));
174
      end
175 176 177
    else
      error(sprintf(' *** %s error: argument non-string inattendu (en %d-iemme position) ***\n', ...
                    mfilename, i));
178 179 180
    end
    i=i+1;
  end
181

182 183 184 185 186 187 188 189 190 191 192 193
  if isempty(ListVar),
    kVar = 1;
    for iG = 1:length(DimData),
      szG = DimData(iG);
      for l = 1:szG
        ListVar{kVar,1} = sprintf('Gr%dVar%d', iG, l);
        kVar = kVar + 1;
      end
    end
  end

  data.colheaders = ListVar;
194

195
  sD = som_data_struct(data.data,'name', data_casename,'comp_names', upper(ListVar));
196 197 198 199 200 201 202 203
  % i=1;
  % while (i<=length(varargin) && bool_norm==0)
  %   if strcmp(varargin{i},'norm')
  %     bool_norm=1; 
  %     type_norm=varargin{i+1};
  %   end
  %   i=i+1;
  % end
204 205 206 207 208 209 210 211 212
  
  %normalisation des donnees
  if bool_norm
    fprintf(1,'\n-- Normalisation des donnees selon ''%s'' ...\n', type_norm);
    if strcmp(type_norm,'simple')
      sD_norm=som_normalize(sD);
    else
      sD_norm=som_normalize(sD,type_norm);
    end
213
  else
214
    fprintf(1,'\n** Pas de normalisation des donnees **\n');
215 216 217 218
    sD_norm = sD;
  end
  
  
219 220 221 222 223 224 225 226 227 228 229 230 231 232 233
  % if ~isempty(varargin)
  %   i=1;
  %   while i<=length(varargin)
  %     if strcmp(varargin{i},'init')
  %       init=varargin{i+1};
  %     end
  %     if strcmp(varargin{i},'tracking')
  %       tracking=varargin{i+1};
  %     end
  %     if strcmp(varargin{i},'lattice')
  %       lattice=varargin{i+1};
  %     end
  %     i=i+1;
  %   end
  % end
234

235 236 237 238 239
  fprintf(1,[ '\n-- ------------------------------------------------------------------\n', ...
              '-- New 2S-SOMTraining function:\n', ...
              '--   %s (''%s'', ''%s'', ''%s'', ... )\n', ...
              '-- ------------------------------------------------------------------\n' ], ...
          mfilename, init, lattice, data_casename);
240
    
241
  %SOM initialisation
242
  if bool_init_with_make
243
    fprintf(1,'\n-- Initialisation avec SOM_MAKE ... ')
244
    sMap=som_make(sD_norm, ...
245 246 247 248 249 250 251 252
                  'munits',   nb_neurone, ...
                  'lattice',  lattice, ...
                  'init',     init, ...
                  'tracking', tracking); % creer la carte initiale avec et effectuer un entrainenemt

  else
    if strcmp(init,'randinit')
      fprintf(1,'\n-- Initialisation avec SOM_RANDINIT ... ')
253
      sMap=som_randinit(sD_norm, ...
254 255 256 257 258 259
                        'munits',   nb_neurone, ...
                        'lattice',  lattice, ...
                        'tracking', tracking); % creer la carte initiale

    elseif strcmp(init,'lininit')
      fprintf(1,'\n-- Initialisation avec SOM_LININIT ... ')
260
      sMap=som_lininit(sD_norm, ...
261 262 263 264 265 266 267 268 269 270 271 272
                       'munits',   nb_neurone, ...
                       'lattice',  lattice, ...
                       'tracking', tracking); % creer la carte initiale

    else
      error(sprintf(['\n *** %s error: invalid ''init'' option ''%s'' ***\n', ...
                     '     Shoud be one between { ''lininit'', ''randinit'' } ***\n' ], ...
                    mfilename, init));
    end
    fprintf(1,' <som init END>.\n')
  end
  
273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298
  % bool_rad=0;
  % bool_trainlen=0;
  % if ~isempty(varargin)
  % 
  %   i=1;
  %   while i<=length(varargin)
  %     if ischar(varargin{i})
  %       switch varargin{i}
  %         case 'radius'
  %           bool_rad=1;
  %           loc_rad=i;
  %           rad=varargin{loc_rad+1};
  %           i=i+1;
  %         case 'trainlen' 
  %           bool_trainlen=1;
  %           loc_trainlen=i;
  %           trlen=varargin{loc_trainlen+1};
  %           i=i+1;
  %         otherwise
  %           i=i+1;
  %       end
  %     else
  %       i=i+1;
  %     end
  %   end

299
  if bool_pre_training
300 301
    pretrain_tracking = tracking;
    %pretrain_tracking = 1;
302
    
303 304 305
    % batchtrain avec radius ...
    if (bool_rad && ~bool_trainlen)
      fprintf(1,'\n-- BATCHTRAIN initial avec radius ... ')
306
      if pretrain_tracking, fprintf(1,'\n'); end
307 308 309
      j=1;
      while j<length(rad)
        
310
        sMap=som_batchtrain(sMap, sD_norm, ...
311 312
                            'radius',[rad(j) rad(j+1)], ...
                            'tracking',pretrain_tracking);
313 314
        j=j+1;
        
315 316
      end
    end
317 318 319
    % batchtrain avec trainlen ...
    if (~bool_rad && bool_trainlen) 
      fprintf(1,'\n-- BATCHTRAIN initial avec trainlen ... ')
320
      if pretrain_tracking, fprintf(1,'\n'); end
321 322 323
      j=1;
      while j<=length(trlen)
        
324
        sMap=som_batchtrain(sMap, sD_norm, ...
325 326
                            'trainlen',trlen(j), ...
                            'tracking',pretrain_tracking);
327 328 329 330 331
        j=j+1;
        
      end
    end
    % batchtrain avec radius et trainlen             
332
    if (bool_rad && bool_trainlen)
333
      fprintf(1,'\n-- BATCHTRAIN initial avec radius et trainlen ... \n')
334
      if pretrain_tracking, fprintf(1,'\n'); end
335 336
      if length(rad)==length(trlen)+1
        
337 338 339
        j=1;
        while j<length(rad)
          
340
          sMap=som_batchtrain(sMap, sD_norm, ...
341 342 343
                              'radius',[rad(j) rad(j+1)], ...
                              'trainlen',trlen(j), ...
                              'tracking',pretrain_tracking);
344 345 346
          j=j+1;
          
        end
347 348
      else
        error('vecteur radius doit avoir un element en plus que le vecteur trainlen ')
349
      end
350
    end
351 352
    sMapPT = sMap;

353 354 355 356
    current_perf = som_distortion(sMap,sD_norm);
    fprintf(1,'--> som_distortion apres entrainement initiale = %s\n', num2str(current_perf));
    
  else
357
    fprintf(1,'** batchtrain initial (pre-training) non active **\n')
358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405
  end
  % end
  
  % %S2-SOM
  % bool_2ssom=0;
  % bool_DimData=0;
  % bool_lambda=0;
  % bool_eta=0;
  %
  % if ~isempty(varargin)
  %   i=1;
  %   while i<=length(varargin)
  %     if ischar(varargin{i}) 
  %       switch varargin{i} 
  %      
  %         case 'S2-SOM'
  %           disp('** S2-SOM Active **');
  %           bool_2ssom=1;
  %           i=i+1;
  %           %mettre en bloc
  %         case 'DimData'
  %           i=i+1;
  %           DimData=varargin{i};
  %           for di=1:length(DimData)
  %             DimBloc(di).Dim=DimData(di);
  %           end
  %           bool_DimData=1;
  %         case 'lambda' 
  %           i=i+1; 
  %           lambda=varargin{i};
  %           if length(lambda) < 1
  %             error('lambda est de longueur nulle !  Il doit y avoir au moins une valeur')
  %           end
  %           bool_lambda=1;
  %         case 'eta' 
  %           i=i+1; eta=varargin{i};
  %           if length(eta) < 1
  %             error('eta est de longueur nulle !  Il doit y avoir au moins une valeur')
  %           end
  %           bool_eta=1;
  %         otherwise
  %           i=i+1;
  %
  %       end
  %     else
  %       i=i+1;
  %     end
  %   end
406

407
  if (bool_2ssom)
408 409
      %if (bool_lambda && bool_eta && bool_DimData)
      if (bool_lambda && bool_eta)
410
          
411 412 413 414
          n_lambda = length(lambda);
          n_eta    = length(eta);
          i_train = 1;
          n_train = n_lambda*n_eta;
415
          
416 417 418 419 420 421 422
          if ~bool_rad_2s_som
              rad_2s_som =  [rad(round(length(rad)/2)) ...
                  rad((round(length(rad)/2))+1)];
          end
          if ~bool_trlen_2s_som
              trlen_2s_som = trlen(round(length(trlen)/2));
          end
423
          
424 425 426 427 428 429
          fprintf(1,[ '\n-- batchtrainRTOM loop for %d lambda and %d eta values:\n', ...
              '-- ------------------------------------------------------------------\n' ], ...
              n_lambda, n_eta);
          if tracking > 1,
              fprintf(1,'   ... trainlen_2s_som ... %s\n', num2str(trlen_2s_som))
              fprintf(1,'   ... radius_2s_som ..... [%s]\n', join(string(rad_2s_som),', '))
430
          end
431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464
          if bool_parcomp, ticBytes(gcp); end   % POUR CALCUL PARALLELE
          if bool_parcomp
              parcomp_M = parcomp_workers;
          else
              parcomp_M = 1;
          end
          parfor (i=1:n_lambda,parcomp_M)
              for j=1:n_eta
                  fprintf(1,'-- batchtrainRTOM (%d/%d) with lambda=%s and eta=%s ... ',(i - 1) * n_eta + j, ...
                      n_train, num2str(lambda(i)),num2str(eta(j)));
                  if tracking, fprintf(1,'\n'); end
                  
                  [Result(i,j).sMap Result(i,j).bmus Result(i,j).Alpha Result(i,j).Beta] = som_batchtrainRTOM( ...
                      sMap, sD_norm, ...
                      'TypeAlgo', '2SSOM', ...
                      'DimData',  DimData, ...
                      'DimBloc',  DimBloc, ...
                      'lambda',   lambda(i), ...
                      'eta',      eta(j), ...
                      'radius',   rad_2s_som, ...
                      'trainlen', trlen_2s_som, ...
                      'tracking', tracking);
                  
                  Result(i,j).lambda  = lambda(i);
                  Result(i,j).eta     = eta(j);
                  Result(i,j).DimData = DimData;
                  
                  current_perf = som_distortion(Result(i,j).sMap,sD_norm);
                  fprintf(1,'   --> som_distortion=%s\n', num2str(current_perf));
                  
                  Result(i,j).Perf = current_perf;
              end
          end
          if bool_parcomp, tocBytes(gcp), end   % POUR CALCUL PARALLELE
465
          
466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490
          % best_i   = 0;
          % best_j   = 0;
          % bestperf = inf;
          % for i=1:n_lambda
          %   for j=1:n_eta
          %       %  end
          %       %end
          %       % best_i=0;
          %       % best_j=0;
          %       % bestperf=inf;
          %       % for i=1:n_lambda
          %       %   for j=1:n_eta
          %       %
          %       if Result(i,j).Perf < bestperf
          %           best_i = i;
          %           best_j = j;
          %           bestperf = Result(i,j).Perf;
          %       end
          %   end
          % end
          %
          % sMap = Result(best_i,best_j).sMap;
          
      else
          error('manque de parametre: specifier les valeurs pour LAMBDA, pour ETA ou pour les deux!')
491
      end
492 493 494 495 496 497
  elseif (bool_lambda || bool_eta)
    error([ '*** %s: PAS DE 2SSOM SPECIFIE MAIS FLAGS (LAMBDA ou ETA) ACTIVE ***\n', ...
            '    mentionnez si vous voulez ''S2-SOM''\n' ], mfilename)
  else
    fprintf(1,[ '*** %s: PAS DE 2SSOM SPECIFIE ***\n', ...
                '    mentionnez si vous voulez ''S2-SOM''\n' ], mfilename)
498 499
  end
  
500
  % end
501
  
502 503 504 505 506 507 508 509 510 511 512
  clear St
  if (bool_2ssom)
    % si 2S-SOM alors best perf sMap
    [BestPerf,iBest] = min(cell2mat({Result.Perf}));
    
    sMap = Result(iBest).sMap;
  else
    % sinon, si pas 2S-SOM
    sMap = sMapPT;
  end
  
513 514
  % denormalisation de la Map
  if bool_norm
515 516 517 518 519 520 521 522
    sMap_dnrm = som_denormalize(sMap,sD_norm.comp_norm);
  else
    sMap_dnrm = sMap;
  end
  
  if bool_return_struct
    % Si retour STRUCT
    St.sMap     = sMap;
523
    St.sD       = sD;
524 525 526 527 528
    
    if ~bool_2ssom
      St.bmus = som_bmus(sMap,sD);
    end
    
529 530
    if bool_norm
      St.sMap_denorm = sMap_dnrm;
531
      St.sD_norm = sD_norm;
532 533 534
    end
    
    if (bool_2ssom)
535 536 537 538 539 540
      St.lambda  = Result(iBest).lambda;
      St.eta     = Result(iBest).eta;
      St.bmus    = Result(iBest).bmus;
      St.Alpha   = Result(iBest).Alpha;
      St.Beta    = Result(iBest).Beta;
      St.DimData = Result(iBest).DimData;
541 542 543 544 545 546 547 548 549 550 551
      
      St.sMapPT   = sMapPT;
      if bool_norm
        St.sMapPT_denorm = som_denormalize(sMapPT,sD_norm.comp_norm);
      end
      
      St.Result   = Result;
      St.iBest    = iBest;
    end
    
    StsMap = St; % variable de retour
552
  else
553 554 555 556 557
    % Sinon, alors retour variables  ... (sMap sMap_denorm Result)
    StsMap      = sMap;
    sMap_denorm = sMap_dnrm;
    Resultout   = Result;
    sMapPTout   = sMapPT;
558 559 560
  end
  
  return