Docker-in-Docker (DinD) capabilities of public runners deactivated. More info

learn_2s_som.m 15.5 KB
Newer Older
1
function [StsMap sMap_denorm Resultout sMapPTout] = learn_2s_som(A,nb_neurone,varargin)
2 3
% Cree la carte SOM ou S2-SOM Pour donnees cachees
%
4 5 6 7 8 9
% Usage:
%
%    [sMap, sMap_denorm, Result, sMapPT] = learn_2s_som(A, nb_neurone, <OPTIOS>)
%
%    St = learn_2s_som(A,nb_neurone, '-struct', <OPTIOS>)
%
10 11 12 13 14
% En entree obligatoire
%
%   A: les donnees cachees
%   nb_neurone: Nombre de neurones 
%
15 16 17 18
% En option, elles sont specifies par couples de valeurs, p. exemple:
%
%         'radius', [ 5 1 ], 'trainlen', 20, 'tracking', 1, ...
%  
19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37
%
%   radius: en forme de vecteur, chaque deux elements qui ce suivent constitue
%           une temperature [i..i+1],[i+1..i+2],....
%   trainlen: en forme de vecteur: chaque element constitue une itération de
%           l'entraienement. NB:vecteur radius doit avoir un element en plus
%           que le vecteur trainlen.
%   tracking: pour visualiser l'apprentissage.
%
%   'S2-SOM': pour faire l'apprentissage avec S2-SOM. Si 'S2-SOM' est
%           specifie alors il faut d'autres parametres:
%
%   DimData: vecteur contenant la dimention de chaque bloc.
%   lambda: vecteur, c'est un hyperparametre pour calculer le poids sur les
%           blocs.
%   eta: vecteur, c'est un hyperparametre pour calculer le poids sur les
%           variables.
%
% En sortie
%
38
%   sMap: La carte SOM ou S2-SOM au point de meilleur "Perf".
39 40 41
%
%   sMap_denorm: La carte SOM ou S2-SOM, denormalisee.
%
42 43
%   iBest = indice dans Result de la carte sMap retournee.
%
44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60
%   Result: structure (vecteur) avec les sorties ou resultats de chaque cas
%           entraine (avec une paire distincte de la combinaison entre lambda
%           et eta): sMap, bmus, Alpha, Beta, Perf.
%
%           Champs de Result:
%              sMap: La carte SOM ou S2-SOM du cas.
%              bmus: Bmus (best matching units) sur toute la zone.
%              Alpha: Coefficients Alpha multipliant les groupes.
%              Beta: Coefficients Alpha multipliant les variables au sans de la
%                  Carte Topologique.
%              Perf: parametre "distortion measure for the map", calcule par
%                  la fonction som_distortion.
%
%   (bmus_pixel:(best matching units) par pixel.)
% Detailed explanation goes here
  
  
61 62 63 64
% Valeurs par defaut
  tracking       = 0;
  init           = 'lininit';
  lattice        = 'rect';
65 66
    
  % flags et variables associees
67 68 69 70 71 72 73 74 75 76 77
  bool_verbose        = false;
  bool_return_struct  = false;
  bool_norm           = false; type_norm     = 'simple';
  bool_rad            = false; rad           = [5 1];
  bool_trainlen       = false; trlen         = 20;
  bool_rad_2s_som     = false; rad_2s_som    = [];
  bool_trlen_2s_som   = false; trlen_2s_som  = [];
  bool_2ssom          = false;
  bool_DimData        =  false; DimData       = [size(A,2)];
  bool_lambda         = false; lambda        = 1;
  bool_eta            = false; eta           = 1000;
78

79
  Result              = [];
80
  
81
  bool_init_with_make = true;
82
  bool_pre_training   = true;
83 84 85 86 87
  
  %recuperer les donnees
  data.data=A;
  label=[1:size(data.data,2)];
  
88
  %Labelise les donnees (affectation apres boucle d'arguments (selon la valeur de DimBloc)
89 90 91 92
  ListVar={};
  
  data_casename='simulation';
  
93 94 95 96 97
  % --- CM pour ajouter les arguments 'data_name' et 'comp_names'
  i=1;
  while (i<=length(varargin))
    if ischar(varargin{i})
      switch varargin{i},
98
        case { 'verbose', '-verbose' },
99
          bool_verbose = true;
100 101
        case { 'returnstruct', 'return-struct', 'struct', '-return-struct', '-struct' },
          bool_return_struct = true;
102
        case { 'data_name' },
103
          data_casename = varargin{i+1}; i=i+1;
104
        case { 'comp_names' },
105
          ListVar = varargin{i+1}; i=i+1;
106
        case { 'norm' },
107
          bool_norm = true; 
108 109 110 111 112 113 114 115
          type_norm = varargin{i+1}; i=i+1;
        case { 'init' },
          init = varargin{i+1}; i=i+1;
        case { 'tracking' },
          tracking = varargin{i+1}; i=i+1;
        case { 'lattice' },
          lattice = varargin{i+1}; i=i+1;
        case 'radius'
116
          bool_rad = true;
117 118
          rad = varargin{i+1}; i=i+1;
        case 'trainlen' 
119
          bool_trainlen = true;
120
          trlen = varargin{i+1}; i=i+1;
121 122 123 124 125 126
        case 'radius-2s-som'
          bool_rad_2s_som = true;
          rad_2s_som = varargin{i+1}; i=i+1;
        case 'trainlen-2s-som' 
          bool_trlen_2s_som = true;
          trlen_2s_som = varargin{i+1}; i=i+1;
127 128
        case 'S2-SOM'
          disp('** S2-SOM Active **');
129
          bool_2ssom = true;
130 131 132 133 134
        case 'DimData'
          DimData = varargin{i+1}; i=i+1;
          for di=1:length(DimData)
            DimBloc(di).Dim = DimData(di);
          end
135
          bool_DimData = true;
136 137 138 139 140
        case 'lambda' 
          lambda=varargin{i+1}; i=i+1;
          if length(lambda) < 1
            error('lambda est de longueur nulle !  Il doit y avoir au moins une valeur')
          end
141
          bool_lambda = true;
142 143 144 145 146
        case 'eta' 
          eta = varargin{i+1}; i=i+1;
          if length(eta) < 1
            error('eta est de longueur nulle !  Il doit y avoir au moins une valeur')
          end
147
          bool_eta = true;
148 149 150 151
        case 'ini-with-make'
          bool_init_with_make = true;
        case 'no-ini-with-make'
          bool_init_with_make = false;
152 153 154 155
        case 'pre-training'
          bool_pre_training   = true;
        case 'no-pre-training'
          bool_pre_training   = false;
156 157 158
        otherwise
          error(sprintf(' *** %s error: argument(%d) ''%s'' inconnu ***\n', ...
                        mfilename, i, varargin{i}));
159
      end
160 161 162
    else
      error(sprintf(' *** %s error: argument non-string inattendu (en %d-iemme position) ***\n', ...
                    mfilename, i));
163 164 165
    end
    i=i+1;
  end
166 167 168 169 170 171 172 173 174 175 176 177 178
  
  if isempty(ListVar),
    kVar = 1;
    for iG = 1:length(DimData),
      szG = DimData(iG);
      for l = 1:szG
        ListVar{kVar,1} = sprintf('Gr%dVar%d', iG, l);
        kVar = kVar + 1;
      end
    end
  end

  data.colheaders = ListVar;
179

180
  sD = som_data_struct(data.data,'name', data_casename,'comp_names', upper(ListVar));
181 182 183 184 185 186 187 188
  % i=1;
  % while (i<=length(varargin) && bool_norm==0)
  %   if strcmp(varargin{i},'norm')
  %     bool_norm=1; 
  %     type_norm=varargin{i+1};
  %   end
  %   i=i+1;
  % end
189 190 191 192 193 194 195 196 197
  
  %normalisation des donnees
  if bool_norm
    fprintf(1,'\n-- Normalisation des donnees selon ''%s'' ...\n', type_norm);
    if strcmp(type_norm,'simple')
      sD_norm=som_normalize(sD);
    else
      sD_norm=som_normalize(sD,type_norm);
    end
198
  else
199
    fprintf(1,'\n** Pas de normalisation des donnees **\n');
200 201 202 203
    sD_norm = sD;
  end
  
  
204 205 206 207 208 209 210 211 212 213 214 215 216 217 218
  % if ~isempty(varargin)
  %   i=1;
  %   while i<=length(varargin)
  %     if strcmp(varargin{i},'init')
  %       init=varargin{i+1};
  %     end
  %     if strcmp(varargin{i},'tracking')
  %       tracking=varargin{i+1};
  %     end
  %     if strcmp(varargin{i},'lattice')
  %       lattice=varargin{i+1};
  %     end
  %     i=i+1;
  %   end
  % end
219

220 221 222 223 224
  fprintf(1,[ '\n-- ------------------------------------------------------------------\n', ...
              '-- New 2S-SOMTraining function:\n', ...
              '--   %s (''%s'', ''%s'', ''%s'', ... )\n', ...
              '-- ------------------------------------------------------------------\n' ], ...
          mfilename, init, lattice, data_casename);
225
    
226
  %SOM initialisation
227
  if bool_init_with_make
228
    fprintf(1,'\n-- Initialisation avec SOM_MAKE ... ')
229
    sMap=som_make(sD_norm, ...
230 231 232 233 234 235 236 237
                  'munits',   nb_neurone, ...
                  'lattice',  lattice, ...
                  'init',     init, ...
                  'tracking', tracking); % creer la carte initiale avec et effectuer un entrainenemt

  else
    if strcmp(init,'randinit')
      fprintf(1,'\n-- Initialisation avec SOM_RANDINIT ... ')
238
      sMap=som_randinit(sD_norm, ...
239 240 241 242 243 244
                        'munits',   nb_neurone, ...
                        'lattice',  lattice, ...
                        'tracking', tracking); % creer la carte initiale

    elseif strcmp(init,'lininit')
      fprintf(1,'\n-- Initialisation avec SOM_LININIT ... ')
245
      sMap=som_lininit(sD_norm, ...
246 247 248 249 250 251 252 253 254 255 256 257
                       'munits',   nb_neurone, ...
                       'lattice',  lattice, ...
                       'tracking', tracking); % creer la carte initiale

    else
      error(sprintf(['\n *** %s error: invalid ''init'' option ''%s'' ***\n', ...
                     '     Shoud be one between { ''lininit'', ''randinit'' } ***\n' ], ...
                    mfilename, init));
    end
    fprintf(1,' <som init END>.\n')
  end
  
258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283
  % bool_rad=0;
  % bool_trainlen=0;
  % if ~isempty(varargin)
  % 
  %   i=1;
  %   while i<=length(varargin)
  %     if ischar(varargin{i})
  %       switch varargin{i}
  %         case 'radius'
  %           bool_rad=1;
  %           loc_rad=i;
  %           rad=varargin{loc_rad+1};
  %           i=i+1;
  %         case 'trainlen' 
  %           bool_trainlen=1;
  %           loc_trainlen=i;
  %           trlen=varargin{loc_trainlen+1};
  %           i=i+1;
  %         otherwise
  %           i=i+1;
  %       end
  %     else
  %       i=i+1;
  %     end
  %   end

284
  if bool_pre_training
285 286
    pretrain_tracking = tracking;
    %pretrain_tracking = 1;
287
    
288 289 290
    % batchtrain avec radius ...
    if (bool_rad && ~bool_trainlen)
      fprintf(1,'\n-- BATCHTRAIN initial avec radius ... ')
291
      if pretrain_tracking, fprintf(1,'\n'); end
292 293 294
      j=1;
      while j<length(rad)
        
295
        sMap=som_batchtrain(sMap, sD_norm, ...
296 297
                            'radius',[rad(j) rad(j+1)], ...
                            'tracking',pretrain_tracking);
298 299
        j=j+1;
        
300 301
      end
    end
302 303 304
    % batchtrain avec trainlen ...
    if (~bool_rad && bool_trainlen) 
      fprintf(1,'\n-- BATCHTRAIN initial avec trainlen ... ')
305
      if pretrain_tracking, fprintf(1,'\n'); end
306 307 308
      j=1;
      while j<=length(trlen)
        
309
        sMap=som_batchtrain(sMap, sD_norm, ...
310 311
                            'trainlen',trlen(j), ...
                            'tracking',pretrain_tracking);
312 313 314 315 316
        j=j+1;
        
      end
    end
    % batchtrain avec radius et trainlen             
317
    if (bool_rad && bool_trainlen)
318
      fprintf(1,'\n-- BATCHTRAIN initial avec radius et trainlen ... \n')
319
      if pretrain_tracking, fprintf(1,'\n'); end
320 321
      if length(rad)==length(trlen)+1
        
322 323 324
        j=1;
        while j<length(rad)
          
325
          sMap=som_batchtrain(sMap, sD_norm, ...
326 327 328
                              'radius',[rad(j) rad(j+1)], ...
                              'trainlen',trlen(j), ...
                              'tracking',pretrain_tracking);
329 330 331
          j=j+1;
          
        end
332 333
      else
        error('vecteur radius doit avoir un element en plus que le vecteur trainlen ')
334
      end
335
    end
336 337
    sMapPT = sMap;

338 339 340 341
    current_perf = som_distortion(sMap,sD_norm);
    fprintf(1,'--> som_distortion apres entrainement initiale = %s\n', num2str(current_perf));
    
  else
342
    fprintf(1,'** batchtrain initial (pre-training) non active **\n')
343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392
  end
  % end
  
  % %S2-SOM
  % bool_2ssom=0;
  % bool_DimData=0;
  % bool_lambda=0;
  % bool_eta=0;
  %
  % if ~isempty(varargin)
  %   i=1;
  %   while i<=length(varargin)
  %     if ischar(varargin{i}) 
  %       switch varargin{i} 
  %      
  %         case 'S2-SOM'
  %           disp('** S2-SOM Active **');
  %           bool_2ssom=1;
  %           i=i+1;
  %           %mettre en bloc
  %         case 'DimData'
  %           i=i+1;
  %           DimData=varargin{i};
  %           for di=1:length(DimData)
  %             DimBloc(di).Dim=DimData(di);
  %           end
  %           bool_DimData=1;
  %         case 'lambda' 
  %           i=i+1; 
  %           lambda=varargin{i};
  %           if length(lambda) < 1
  %             error('lambda est de longueur nulle !  Il doit y avoir au moins une valeur')
  %           end
  %           bool_lambda=1;
  %         case 'eta' 
  %           i=i+1; eta=varargin{i};
  %           if length(eta) < 1
  %             error('eta est de longueur nulle !  Il doit y avoir au moins une valeur')
  %           end
  %           bool_eta=1;
  %         otherwise
  %           i=i+1;
  %
  %       end
  %     else
  %       i=i+1;
  %     end
  %   end
  
  if (bool_2ssom)
393
    if (bool_lambda && bool_eta && bool_DimData)
394 395 396 397 398 399 400
      
      best_i   = 0;
      best_j   = 0;
      bestperf = inf;
      
      i_train = 1;
      n_train = length(lambda)*length(eta);
401
      
402 403 404 405 406 407 408
      if ~bool_rad_2s_som
        rad_2s_som =  [rad(round(length(rad)/2)) ...
                          rad((round(length(rad)/2))+1)];
      end
      if ~bool_trlen_2s_som
        trlen_2s_som = trlen(round(length(trlen)/2));
      end
409

410 411 412
      fprintf(1,[ '\n-- batchtrainRTOM loop for %d lambda and %d eta values:\n', ... 
                  '-- ------------------------------------------------------------------\n' ], ...
              length(lambda), length(eta));
413
      if tracking > 1,
414 415
        fprintf(1,'   ... trainlen_2s_som ... %s\n', num2str(trlen_2s_som))
        fprintf(1,'   ... radius_2s_som ..... [%s]\n', join(string(rad_2s_som),', '))
416
      end
417 418
      for i=1:length(lambda)
        for j=1:length(eta)
419
          fprintf(1,'-- batchtrainRTOM (%d/%d) with lambda=%s and eta=%s ... ',i_train, ...
420 421
                  n_train, num2str(lambda(i)),num2str(eta(j)));
          if tracking, fprintf(1,'\n'); end
422
          
423 424
          [Result(i,j).sMap Result(i,j).bmus Result(i,j).Alpha Result(i,j).Beta] = som_batchtrainRTOM( ...
              sMap, sD_norm, ...
425 426 427 428 429 430 431 432 433 434 435
              'TypeAlgo', '2SSOM', ...
              'DimData',  DimData, ...
              'DimBloc',  DimBloc, ...
              'lambda',   lambda(i), ...
              'eta',      eta(j), ...
              'radius',   rad_2s_som, ...
              'trainlen', trlen_2s_som, ...
              'tracking', tracking);
          
          Result(i,j).lambda = lambda(i);
          Result(i,j).eta    = eta(j);
436
          
437 438 439 440 441 442 443 444 445 446 447 448 449 450
          current_perf = som_distortion(Result(i,j).sMap,sD_norm);
          fprintf(1,'   --> som_distortion=%s\n', num2str(current_perf));
          %  end
          %end
          % best_i=0;
          % best_j=0;
          % bestperf=inf;
          % for i=1:length(lambda)
          %   for j=1:length(eta)
          %         
          Result(i,j).Perf = current_perf;
          if Result(i,j).Perf < bestperf
            best_i = i;
            best_j = j;
451
          end
452 453
          
          i_train = i_train + 1;
454 455 456
        end
      end
      
457
      sMap = Result(best_i,best_j).sMap;
458 459
      
    else
460
      error('manque de parametre')
461
    end
462 463
  elseif (bool_lambda || bool_eta || bool_DimData)
    error('mentionnez si vous voulez S2-SOM')
464 465
  end
  
466
  % end
467
  
468 469 470 471 472 473 474 475 476 477 478
  clear St
  if (bool_2ssom)
    % si 2S-SOM alors best perf sMap
    [BestPerf,iBest] = min(cell2mat({Result.Perf}));
    
    sMap = Result(iBest).sMap;
  else
    % sinon, si pas 2S-SOM
    sMap = sMapPT;
  end
  
479 480
  % denormalisation de la Map
  if bool_norm
481 482 483 484 485 486 487 488 489 490 491 492 493
    sMap_dnrm = som_denormalize(sMap,sD_norm.comp_norm);
  else
    sMap_dnrm = sMap;
  end
  
  if bool_return_struct
    % Si retour STRUCT
    St.sMap     = sMap;
    if bool_norm
      St.sMap_denorm = sMap_dnrm;
    end
    
    if (bool_2ssom)
494 495
      St.lambda = Result(iBest).lambda;
      St.eta    = Result(iBest).eta;
496 497 498 499 500 501 502 503 504 505 506 507 508 509
      St.bmus   = Result(iBest).bmus;
      St.Alpha  = Result(iBest).Alpha;
      St.Beta   = Result(iBest).Beta;
      
      St.sMapPT   = sMapPT;
      if bool_norm
        St.sMapPT_denorm = som_denormalize(sMapPT,sD_norm.comp_norm);
      end
      
      St.Result   = Result;
      St.iBest    = iBest;
    end
    
    StsMap = St; % variable de retour
510
  else
511 512 513 514 515
    % Sinon, alors retour variables  ... (sMap sMap_denorm Result)
    StsMap      = sMap;
    sMap_denorm = sMap_dnrm;
    Resultout   = Result;
    sMapPTout   = sMapPT;
516 517 518
  end
  
  return