learn_2s_som.m 15.6 KB
Newer Older
1
function [StsMap sMap_denorm Resultout sMapPTout] = learn_2s_som(A,nb_neurone,varargin)
2 3
% Cree la carte SOM ou S2-SOM Pour donnees cachees
%
4 5 6 7 8 9
% Usage:
%
%    [sMap, sMap_denorm, Result, sMapPT] = learn_2s_som(A, nb_neurone, <OPTIOS>)
%
%    St = learn_2s_som(A,nb_neurone, '-struct', <OPTIOS>)
%
10 11 12 13 14
% En entree obligatoire
%
%   A: les donnees cachees
%   nb_neurone: Nombre de neurones 
%
15 16 17 18
% En option, elles sont specifies par couples de valeurs, p. exemple:
%
%         'radius', [ 5 1 ], 'trainlen', 20, 'tracking', 1, ...
%  
19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37
%
%   radius: en forme de vecteur, chaque deux elements qui ce suivent constitue
%           une temperature [i..i+1],[i+1..i+2],....
%   trainlen: en forme de vecteur: chaque element constitue une itération de
%           l'entraienement. NB:vecteur radius doit avoir un element en plus
%           que le vecteur trainlen.
%   tracking: pour visualiser l'apprentissage.
%
%   'S2-SOM': pour faire l'apprentissage avec S2-SOM. Si 'S2-SOM' est
%           specifie alors il faut d'autres parametres:
%
%   DimData: vecteur contenant la dimention de chaque bloc.
%   lambda: vecteur, c'est un hyperparametre pour calculer le poids sur les
%           blocs.
%   eta: vecteur, c'est un hyperparametre pour calculer le poids sur les
%           variables.
%
% En sortie
%
38
%   sMap: La carte SOM ou S2-SOM au point de meilleur "Perf".
39 40 41
%
%   sMap_denorm: La carte SOM ou S2-SOM, denormalisee.
%
42 43
%   iBest = indice dans Result de la carte sMap retournee.
%
44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60
%   Result: structure (vecteur) avec les sorties ou resultats de chaque cas
%           entraine (avec une paire distincte de la combinaison entre lambda
%           et eta): sMap, bmus, Alpha, Beta, Perf.
%
%           Champs de Result:
%              sMap: La carte SOM ou S2-SOM du cas.
%              bmus: Bmus (best matching units) sur toute la zone.
%              Alpha: Coefficients Alpha multipliant les groupes.
%              Beta: Coefficients Alpha multipliant les variables au sans de la
%                  Carte Topologique.
%              Perf: parametre "distortion measure for the map", calcule par
%                  la fonction som_distortion.
%
%   (bmus_pixel:(best matching units) par pixel.)
% Detailed explanation goes here
  
  
61 62 63 64
% Valeurs par defaut
  tracking       = 0;
  init           = 'lininit';
  lattice        = 'rect';
65 66
    
  % flags et variables associees
67 68 69 70 71 72 73 74 75 76 77
  bool_verbose        = false;
  bool_return_struct  = false;
  bool_norm           = false; type_norm     = 'simple';
  bool_rad            = false; rad           = [5 1];
  bool_trainlen       = false; trlen         = 20;
  bool_rad_2s_som     = false; rad_2s_som    = [];
  bool_trlen_2s_som   = false; trlen_2s_som  = [];
  bool_2ssom          = false;
  bool_DimData        =  false; DimData       = [size(A,2)];
  bool_lambda         = false; lambda        = 1;
  bool_eta            = false; eta           = 1000;
78

79
  Result              = [];
80
  
81
  bool_init_with_make = true;
82
  bool_pre_training   = true;
83 84 85 86 87
  
  %recuperer les donnees
  data.data=A;
  label=[1:size(data.data,2)];
  
88
  %Labelise les donnees (affectation apres boucle d'arguments (selon la valeur de DimBloc)
89 90 91 92
  ListVar={};
  
  data_casename='simulation';
  
93 94 95 96
  % --- CM pour ajouter les arguments 'data_name' et 'comp_names'
  i=1;
  while (i<=length(varargin))
    if ischar(varargin{i})
97
      switch lower(varargin{i}),
98
        case { 'verbose', '-verbose' },
99
          bool_verbose = true;
100 101
        case { 'returnstruct', 'return-struct', 'struct', '-return-struct', '-struct' },
          bool_return_struct = true;
102
        case { 'data_name' },
103
          data_casename = varargin{i+1}; i=i+1;
104
        case { 'comp_names' },
105
          ListVar = varargin{i+1}; i=i+1;
106
        case { 'norm' },
107
          bool_norm = true; 
108 109 110 111 112 113 114 115
          type_norm = varargin{i+1}; i=i+1;
        case { 'init' },
          init = varargin{i+1}; i=i+1;
        case { 'tracking' },
          tracking = varargin{i+1}; i=i+1;
        case { 'lattice' },
          lattice = varargin{i+1}; i=i+1;
        case 'radius'
116
          bool_rad = true;
117 118
          rad = varargin{i+1}; i=i+1;
        case 'trainlen' 
119
          bool_trainlen = true;
120
          trlen = varargin{i+1}; i=i+1;
121 122 123 124 125 126
        case 'radius-2s-som'
          bool_rad_2s_som = true;
          rad_2s_som = varargin{i+1}; i=i+1;
        case 'trainlen-2s-som' 
          bool_trlen_2s_som = true;
          trlen_2s_som = varargin{i+1}; i=i+1;
127
        case 's2-som'
128
          disp('** S2-SOM Active **');
129
          bool_2ssom = true;
130 131 132 133
        case 'no-s2-som'
          disp('** S2-SOM Inactive, only SOM training **');
          bool_2ssom = false;
        case 'dimdata'
134 135 136 137
          DimData = varargin{i+1}; i=i+1;
          for di=1:length(DimData)
            DimBloc(di).Dim = DimData(di);
          end
138
          bool_DimData = true;
139 140 141 142 143
        case 'lambda' 
          lambda=varargin{i+1}; i=i+1;
          if length(lambda) < 1
            error('lambda est de longueur nulle !  Il doit y avoir au moins une valeur')
          end
144
          bool_lambda = true;
145 146 147 148 149
        case 'eta' 
          eta = varargin{i+1}; i=i+1;
          if length(eta) < 1
            error('eta est de longueur nulle !  Il doit y avoir au moins une valeur')
          end
150
          bool_eta = true;
151 152 153 154
        case 'ini-with-make'
          bool_init_with_make = true;
        case 'no-ini-with-make'
          bool_init_with_make = false;
155 156 157 158
        case 'pre-training'
          bool_pre_training   = true;
        case 'no-pre-training'
          bool_pre_training   = false;
159 160 161
        otherwise
          error(sprintf(' *** %s error: argument(%d) ''%s'' inconnu ***\n', ...
                        mfilename, i, varargin{i}));
162
      end
163 164 165
    else
      error(sprintf(' *** %s error: argument non-string inattendu (en %d-iemme position) ***\n', ...
                    mfilename, i));
166 167 168
    end
    i=i+1;
  end
169 170 171 172 173 174 175 176 177 178 179 180 181
  
  if isempty(ListVar),
    kVar = 1;
    for iG = 1:length(DimData),
      szG = DimData(iG);
      for l = 1:szG
        ListVar{kVar,1} = sprintf('Gr%dVar%d', iG, l);
        kVar = kVar + 1;
      end
    end
  end

  data.colheaders = ListVar;
182

183
  sD = som_data_struct(data.data,'name', data_casename,'comp_names', upper(ListVar));
184 185 186 187 188 189 190 191
  % i=1;
  % while (i<=length(varargin) && bool_norm==0)
  %   if strcmp(varargin{i},'norm')
  %     bool_norm=1; 
  %     type_norm=varargin{i+1};
  %   end
  %   i=i+1;
  % end
192 193 194 195 196 197 198 199 200
  
  %normalisation des donnees
  if bool_norm
    fprintf(1,'\n-- Normalisation des donnees selon ''%s'' ...\n', type_norm);
    if strcmp(type_norm,'simple')
      sD_norm=som_normalize(sD);
    else
      sD_norm=som_normalize(sD,type_norm);
    end
201
  else
202
    fprintf(1,'\n** Pas de normalisation des donnees **\n');
203 204 205 206
    sD_norm = sD;
  end
  
  
207 208 209 210 211 212 213 214 215 216 217 218 219 220 221
  % if ~isempty(varargin)
  %   i=1;
  %   while i<=length(varargin)
  %     if strcmp(varargin{i},'init')
  %       init=varargin{i+1};
  %     end
  %     if strcmp(varargin{i},'tracking')
  %       tracking=varargin{i+1};
  %     end
  %     if strcmp(varargin{i},'lattice')
  %       lattice=varargin{i+1};
  %     end
  %     i=i+1;
  %   end
  % end
222

223 224 225 226 227
  fprintf(1,[ '\n-- ------------------------------------------------------------------\n', ...
              '-- New 2S-SOMTraining function:\n', ...
              '--   %s (''%s'', ''%s'', ''%s'', ... )\n', ...
              '-- ------------------------------------------------------------------\n' ], ...
          mfilename, init, lattice, data_casename);
228
    
229
  %SOM initialisation
230
  if bool_init_with_make
231
    fprintf(1,'\n-- Initialisation avec SOM_MAKE ... ')
232
    sMap=som_make(sD_norm, ...
233 234 235 236 237 238 239 240
                  'munits',   nb_neurone, ...
                  'lattice',  lattice, ...
                  'init',     init, ...
                  'tracking', tracking); % creer la carte initiale avec et effectuer un entrainenemt

  else
    if strcmp(init,'randinit')
      fprintf(1,'\n-- Initialisation avec SOM_RANDINIT ... ')
241
      sMap=som_randinit(sD_norm, ...
242 243 244 245 246 247
                        'munits',   nb_neurone, ...
                        'lattice',  lattice, ...
                        'tracking', tracking); % creer la carte initiale

    elseif strcmp(init,'lininit')
      fprintf(1,'\n-- Initialisation avec SOM_LININIT ... ')
248
      sMap=som_lininit(sD_norm, ...
249 250 251 252 253 254 255 256 257 258 259 260
                       'munits',   nb_neurone, ...
                       'lattice',  lattice, ...
                       'tracking', tracking); % creer la carte initiale

    else
      error(sprintf(['\n *** %s error: invalid ''init'' option ''%s'' ***\n', ...
                     '     Shoud be one between { ''lininit'', ''randinit'' } ***\n' ], ...
                    mfilename, init));
    end
    fprintf(1,' <som init END>.\n')
  end
  
261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286
  % bool_rad=0;
  % bool_trainlen=0;
  % if ~isempty(varargin)
  % 
  %   i=1;
  %   while i<=length(varargin)
  %     if ischar(varargin{i})
  %       switch varargin{i}
  %         case 'radius'
  %           bool_rad=1;
  %           loc_rad=i;
  %           rad=varargin{loc_rad+1};
  %           i=i+1;
  %         case 'trainlen' 
  %           bool_trainlen=1;
  %           loc_trainlen=i;
  %           trlen=varargin{loc_trainlen+1};
  %           i=i+1;
  %         otherwise
  %           i=i+1;
  %       end
  %     else
  %       i=i+1;
  %     end
  %   end

287
  if bool_pre_training
288 289
    pretrain_tracking = tracking;
    %pretrain_tracking = 1;
290
    
291 292 293
    % batchtrain avec radius ...
    if (bool_rad && ~bool_trainlen)
      fprintf(1,'\n-- BATCHTRAIN initial avec radius ... ')
294
      if pretrain_tracking, fprintf(1,'\n'); end
295 296 297
      j=1;
      while j<length(rad)
        
298
        sMap=som_batchtrain(sMap, sD_norm, ...
299 300
                            'radius',[rad(j) rad(j+1)], ...
                            'tracking',pretrain_tracking);
301 302
        j=j+1;
        
303 304
      end
    end
305 306 307
    % batchtrain avec trainlen ...
    if (~bool_rad && bool_trainlen) 
      fprintf(1,'\n-- BATCHTRAIN initial avec trainlen ... ')
308
      if pretrain_tracking, fprintf(1,'\n'); end
309 310 311
      j=1;
      while j<=length(trlen)
        
312
        sMap=som_batchtrain(sMap, sD_norm, ...
313 314
                            'trainlen',trlen(j), ...
                            'tracking',pretrain_tracking);
315 316 317 318 319
        j=j+1;
        
      end
    end
    % batchtrain avec radius et trainlen             
320
    if (bool_rad && bool_trainlen)
321
      fprintf(1,'\n-- BATCHTRAIN initial avec radius et trainlen ... \n')
322
      if pretrain_tracking, fprintf(1,'\n'); end
323 324
      if length(rad)==length(trlen)+1
        
325 326 327
        j=1;
        while j<length(rad)
          
328
          sMap=som_batchtrain(sMap, sD_norm, ...
329 330 331
                              'radius',[rad(j) rad(j+1)], ...
                              'trainlen',trlen(j), ...
                              'tracking',pretrain_tracking);
332 333 334
          j=j+1;
          
        end
335 336
      else
        error('vecteur radius doit avoir un element en plus que le vecteur trainlen ')
337
      end
338
    end
339 340
    sMapPT = sMap;

341 342 343 344
    current_perf = som_distortion(sMap,sD_norm);
    fprintf(1,'--> som_distortion apres entrainement initiale = %s\n', num2str(current_perf));
    
  else
345
    fprintf(1,'** batchtrain initial (pre-training) non active **\n')
346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395
  end
  % end
  
  % %S2-SOM
  % bool_2ssom=0;
  % bool_DimData=0;
  % bool_lambda=0;
  % bool_eta=0;
  %
  % if ~isempty(varargin)
  %   i=1;
  %   while i<=length(varargin)
  %     if ischar(varargin{i}) 
  %       switch varargin{i} 
  %      
  %         case 'S2-SOM'
  %           disp('** S2-SOM Active **');
  %           bool_2ssom=1;
  %           i=i+1;
  %           %mettre en bloc
  %         case 'DimData'
  %           i=i+1;
  %           DimData=varargin{i};
  %           for di=1:length(DimData)
  %             DimBloc(di).Dim=DimData(di);
  %           end
  %           bool_DimData=1;
  %         case 'lambda' 
  %           i=i+1; 
  %           lambda=varargin{i};
  %           if length(lambda) < 1
  %             error('lambda est de longueur nulle !  Il doit y avoir au moins une valeur')
  %           end
  %           bool_lambda=1;
  %         case 'eta' 
  %           i=i+1; eta=varargin{i};
  %           if length(eta) < 1
  %             error('eta est de longueur nulle !  Il doit y avoir au moins une valeur')
  %           end
  %           bool_eta=1;
  %         otherwise
  %           i=i+1;
  %
  %       end
  %     else
  %       i=i+1;
  %     end
  %   end
  
  if (bool_2ssom)
396
    if (bool_lambda && bool_eta && bool_DimData)
397 398 399 400 401 402 403
      
      best_i   = 0;
      best_j   = 0;
      bestperf = inf;
      
      i_train = 1;
      n_train = length(lambda)*length(eta);
404
      
405 406 407 408 409 410 411
      if ~bool_rad_2s_som
        rad_2s_som =  [rad(round(length(rad)/2)) ...
                          rad((round(length(rad)/2))+1)];
      end
      if ~bool_trlen_2s_som
        trlen_2s_som = trlen(round(length(trlen)/2));
      end
412

413 414 415
      fprintf(1,[ '\n-- batchtrainRTOM loop for %d lambda and %d eta values:\n', ... 
                  '-- ------------------------------------------------------------------\n' ], ...
              length(lambda), length(eta));
416
      if tracking > 1,
417 418
        fprintf(1,'   ... trainlen_2s_som ... %s\n', num2str(trlen_2s_som))
        fprintf(1,'   ... radius_2s_som ..... [%s]\n', join(string(rad_2s_som),', '))
419
      end
420 421
      for i=1:length(lambda)
        for j=1:length(eta)
422
          fprintf(1,'-- batchtrainRTOM (%d/%d) with lambda=%s and eta=%s ... ',i_train, ...
423 424
                  n_train, num2str(lambda(i)),num2str(eta(j)));
          if tracking, fprintf(1,'\n'); end
425
          
426 427
          [Result(i,j).sMap Result(i,j).bmus Result(i,j).Alpha Result(i,j).Beta] = som_batchtrainRTOM( ...
              sMap, sD_norm, ...
428 429 430 431 432 433 434 435 436 437 438
              'TypeAlgo', '2SSOM', ...
              'DimData',  DimData, ...
              'DimBloc',  DimBloc, ...
              'lambda',   lambda(i), ...
              'eta',      eta(j), ...
              'radius',   rad_2s_som, ...
              'trainlen', trlen_2s_som, ...
              'tracking', tracking);
          
          Result(i,j).lambda = lambda(i);
          Result(i,j).eta    = eta(j);
439
          
440 441 442 443 444 445 446 447 448 449 450 451 452 453
          current_perf = som_distortion(Result(i,j).sMap,sD_norm);
          fprintf(1,'   --> som_distortion=%s\n', num2str(current_perf));
          %  end
          %end
          % best_i=0;
          % best_j=0;
          % bestperf=inf;
          % for i=1:length(lambda)
          %   for j=1:length(eta)
          %         
          Result(i,j).Perf = current_perf;
          if Result(i,j).Perf < bestperf
            best_i = i;
            best_j = j;
454
          end
455 456
          
          i_train = i_train + 1;
457 458 459
        end
      end
      
460
      sMap = Result(best_i,best_j).sMap;
461 462
      
    else
463
      error('manque de parametre')
464
    end
465 466
  elseif (bool_lambda || bool_eta || bool_DimData)
    error('mentionnez si vous voulez S2-SOM')
467 468
  end
  
469
  % end
470
  
471 472 473 474 475 476 477 478 479 480 481
  clear St
  if (bool_2ssom)
    % si 2S-SOM alors best perf sMap
    [BestPerf,iBest] = min(cell2mat({Result.Perf}));
    
    sMap = Result(iBest).sMap;
  else
    % sinon, si pas 2S-SOM
    sMap = sMapPT;
  end
  
482 483
  % denormalisation de la Map
  if bool_norm
484 485 486 487 488 489 490 491 492 493 494 495 496
    sMap_dnrm = som_denormalize(sMap,sD_norm.comp_norm);
  else
    sMap_dnrm = sMap;
  end
  
  if bool_return_struct
    % Si retour STRUCT
    St.sMap     = sMap;
    if bool_norm
      St.sMap_denorm = sMap_dnrm;
    end
    
    if (bool_2ssom)
497 498
      St.lambda = Result(iBest).lambda;
      St.eta    = Result(iBest).eta;
499 500 501 502 503 504 505 506 507 508 509 510 511 512
      St.bmus   = Result(iBest).bmus;
      St.Alpha  = Result(iBest).Alpha;
      St.Beta   = Result(iBest).Beta;
      
      St.sMapPT   = sMapPT;
      if bool_norm
        St.sMapPT_denorm = som_denormalize(sMapPT,sD_norm.comp_norm);
      end
      
      St.Result   = Result;
      St.iBest    = iBest;
    end
    
    StsMap = St; % variable de retour
513
  else
514 515 516 517 518
    % Sinon, alors retour variables  ... (sMap sMap_denorm Result)
    StsMap      = sMap;
    sMap_denorm = sMap_dnrm;
    Resultout   = Result;
    sMapPTout   = sMapPT;
519 520 521
  end
  
  return