Docker-in-Docker (DinD) capabilities of public runners deactivated. More info

learn_2s_som.m 15.3 KB
Newer Older
1
function [StsMap sMap_denorm Resultout sMapPTout] = learn_2s_som(A,nb_neurone,varargin)
2 3
% Cree la carte SOM ou S2-SOM Pour donnees cachees
%
4 5 6 7 8 9
% Usage:
%
%    [sMap, sMap_denorm, Result, sMapPT] = learn_2s_som(A, nb_neurone, <OPTIOS>)
%
%    St = learn_2s_som(A,nb_neurone, '-struct', <OPTIOS>)
%
10 11 12 13 14
% En entree obligatoire
%
%   A: les donnees cachees
%   nb_neurone: Nombre de neurones 
%
15 16 17 18
% En option, elles sont specifies par couples de valeurs, p. exemple:
%
%         'radius', [ 5 1 ], 'trainlen', 20, 'tracking', 1, ...
%  
19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37
%
%   radius: en forme de vecteur, chaque deux elements qui ce suivent constitue
%           une temperature [i..i+1],[i+1..i+2],....
%   trainlen: en forme de vecteur: chaque element constitue une itération de
%           l'entraienement. NB:vecteur radius doit avoir un element en plus
%           que le vecteur trainlen.
%   tracking: pour visualiser l'apprentissage.
%
%   'S2-SOM': pour faire l'apprentissage avec S2-SOM. Si 'S2-SOM' est
%           specifie alors il faut d'autres parametres:
%
%   DimData: vecteur contenant la dimention de chaque bloc.
%   lambda: vecteur, c'est un hyperparametre pour calculer le poids sur les
%           blocs.
%   eta: vecteur, c'est un hyperparametre pour calculer le poids sur les
%           variables.
%
% En sortie
%
38
%   sMap: La carte SOM ou S2-SOM au point de meilleur "Perf".
39 40 41
%
%   sMap_denorm: La carte SOM ou S2-SOM, denormalisee.
%
42 43
%   iBest = indice dans Result de la carte sMap retournee.
%
44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60
%   Result: structure (vecteur) avec les sorties ou resultats de chaque cas
%           entraine (avec une paire distincte de la combinaison entre lambda
%           et eta): sMap, bmus, Alpha, Beta, Perf.
%
%           Champs de Result:
%              sMap: La carte SOM ou S2-SOM du cas.
%              bmus: Bmus (best matching units) sur toute la zone.
%              Alpha: Coefficients Alpha multipliant les groupes.
%              Beta: Coefficients Alpha multipliant les variables au sans de la
%                  Carte Topologique.
%              Perf: parametre "distortion measure for the map", calcule par
%                  la fonction som_distortion.
%
%   (bmus_pixel:(best matching units) par pixel.)
% Detailed explanation goes here
  
  
61 62 63 64
% Valeurs par defaut
  tracking       = 0;
  init           = 'lininit';
  lattice        = 'rect';
65 66
    
  % flags et variables associees
67 68 69 70 71 72 73 74 75 76 77
  bool_verbose        = false;
  bool_return_struct  = false;
  bool_norm           = false; type_norm     = 'simple';
  bool_rad            = false; rad           = [5 1];
  bool_trainlen       = false; trlen         = 20;
  bool_rad_2s_som     = false; rad_2s_som    = [];
  bool_trlen_2s_som   = false; trlen_2s_som  = [];
  bool_2ssom          = false;
  bool_DimData        =  false; DimData       = [size(A,2)];
  bool_lambda         = false; lambda        = 1;
  bool_eta            = false; eta           = 1000;
78

79
  Result              = [];
80
  
81
  bool_init_with_make = true;
82
  bool_pre_training   = true;
83 84 85 86 87
  
  %recuperer les donnees
  data.data=A;
  label=[1:size(data.data,2)];
  
88
  %Labelise les donnees (affectation apres boucle d'arguments (selon la valeur de DimBloc)
89 90 91 92
  ListVar={};
  
  data_casename='simulation';
  
93 94 95 96 97
  % --- CM pour ajouter les arguments 'data_name' et 'comp_names'
  i=1;
  while (i<=length(varargin))
    if ischar(varargin{i})
      switch varargin{i},
98
        case { 'verbose', '-verbose' },
99
          bool_verbose = true;
100 101
        case { 'returnstruct', 'return-struct', 'struct', '-return-struct', '-struct' },
          bool_return_struct = true;
102
        case { 'data_name' },
103
          data_casename = varargin{i+1}; i=i+1;
104
        case { 'comp_names' },
105
          ListVar = varargin{i+1}; i=i+1;
106
        case { 'norm' },
107
          bool_norm = true; 
108 109 110 111 112 113 114 115
          type_norm = varargin{i+1}; i=i+1;
        case { 'init' },
          init = varargin{i+1}; i=i+1;
        case { 'tracking' },
          tracking = varargin{i+1}; i=i+1;
        case { 'lattice' },
          lattice = varargin{i+1}; i=i+1;
        case 'radius'
116
          bool_rad = true;
117 118
          rad = varargin{i+1}; i=i+1;
        case 'trainlen' 
119
          bool_trainlen = true;
120
          trlen = varargin{i+1}; i=i+1;
121 122 123 124 125 126
        case 'radius-2s-som'
          bool_rad_2s_som = true;
          rad_2s_som = varargin{i+1}; i=i+1;
        case 'trainlen-2s-som' 
          bool_trlen_2s_som = true;
          trlen_2s_som = varargin{i+1}; i=i+1;
127 128
        case 'S2-SOM'
          disp('** S2-SOM Active **');
129
          bool_2ssom = true;
130 131 132 133 134
        case 'DimData'
          DimData = varargin{i+1}; i=i+1;
          for di=1:length(DimData)
            DimBloc(di).Dim = DimData(di);
          end
135
          bool_DimData = true;
136 137 138 139 140
        case 'lambda' 
          lambda=varargin{i+1}; i=i+1;
          if length(lambda) < 1
            error('lambda est de longueur nulle !  Il doit y avoir au moins une valeur')
          end
141
          bool_lambda = true;
142 143 144 145 146
        case 'eta' 
          eta = varargin{i+1}; i=i+1;
          if length(eta) < 1
            error('eta est de longueur nulle !  Il doit y avoir au moins une valeur')
          end
147
          bool_eta = true;
148 149 150 151
        case 'ini-with-make'
          bool_init_with_make = true;
        case 'no-ini-with-make'
          bool_init_with_make = false;
152 153 154 155
        case 'pre-training'
          bool_pre_training   = true;
        case 'no-pre-training'
          bool_pre_training   = false;
156 157 158
        otherwise
          error(sprintf(' *** %s error: argument(%d) ''%s'' inconnu ***\n', ...
                        mfilename, i, varargin{i}));
159
      end
160 161 162
    else
      error(sprintf(' *** %s error: argument non-string inattendu (en %d-iemme position) ***\n', ...
                    mfilename, i));
163 164 165
    end
    i=i+1;
  end
166 167 168 169 170 171 172 173 174 175 176 177 178
  
  if isempty(ListVar),
    kVar = 1;
    for iG = 1:length(DimData),
      szG = DimData(iG);
      for l = 1:szG
        ListVar{kVar,1} = sprintf('Gr%dVar%d', iG, l);
        kVar = kVar + 1;
      end
    end
  end

  data.colheaders = ListVar;
179

180
  sD = som_data_struct(data.data,'name', data_casename,'comp_names', upper(ListVar));
181 182 183 184 185 186 187 188
  % i=1;
  % while (i<=length(varargin) && bool_norm==0)
  %   if strcmp(varargin{i},'norm')
  %     bool_norm=1; 
  %     type_norm=varargin{i+1};
  %   end
  %   i=i+1;
  % end
189 190 191 192 193 194 195 196 197
  
  %normalisation des donnees
  if bool_norm
    fprintf(1,'\n-- Normalisation des donnees selon ''%s'' ...\n', type_norm);
    if strcmp(type_norm,'simple')
      sD_norm=som_normalize(sD);
    else
      sD_norm=som_normalize(sD,type_norm);
    end
198
  else
199
    fprintf(1,'\n** Pas de normalisation des donnees **\n');
200 201 202 203
    sD_norm = sD;
  end
  
  
204 205 206 207 208 209 210 211 212 213 214 215 216 217 218
  % if ~isempty(varargin)
  %   i=1;
  %   while i<=length(varargin)
  %     if strcmp(varargin{i},'init')
  %       init=varargin{i+1};
  %     end
  %     if strcmp(varargin{i},'tracking')
  %       tracking=varargin{i+1};
  %     end
  %     if strcmp(varargin{i},'lattice')
  %       lattice=varargin{i+1};
  %     end
  %     i=i+1;
  %   end
  % end
219

220 221 222 223 224
  fprintf(1,[ '\n-- ------------------------------------------------------------------\n', ...
              '-- New 2S-SOMTraining function:\n', ...
              '--   %s (''%s'', ''%s'', ''%s'', ... )\n', ...
              '-- ------------------------------------------------------------------\n' ], ...
          mfilename, init, lattice, data_casename);
225
    
226
  %SOM initialisation
227
  if bool_init_with_make
228
    fprintf(1,'\n-- Initialisation avec SOM_MAKE ... ')
229
    sMap=som_make(sD_norm, ...
230 231 232 233 234 235 236 237
                  'munits',   nb_neurone, ...
                  'lattice',  lattice, ...
                  'init',     init, ...
                  'tracking', tracking); % creer la carte initiale avec et effectuer un entrainenemt

  else
    if strcmp(init,'randinit')
      fprintf(1,'\n-- Initialisation avec SOM_RANDINIT ... ')
238
      sMap=som_randinit(sD_norm, ...
239 240 241 242 243 244
                        'munits',   nb_neurone, ...
                        'lattice',  lattice, ...
                        'tracking', tracking); % creer la carte initiale

    elseif strcmp(init,'lininit')
      fprintf(1,'\n-- Initialisation avec SOM_LININIT ... ')
245
      sMap=som_lininit(sD_norm, ...
246 247 248 249 250 251 252 253 254 255 256 257
                       'munits',   nb_neurone, ...
                       'lattice',  lattice, ...
                       'tracking', tracking); % creer la carte initiale

    else
      error(sprintf(['\n *** %s error: invalid ''init'' option ''%s'' ***\n', ...
                     '     Shoud be one between { ''lininit'', ''randinit'' } ***\n' ], ...
                    mfilename, init));
    end
    fprintf(1,' <som init END>.\n')
  end
  
258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283
  % bool_rad=0;
  % bool_trainlen=0;
  % if ~isempty(varargin)
  % 
  %   i=1;
  %   while i<=length(varargin)
  %     if ischar(varargin{i})
  %       switch varargin{i}
  %         case 'radius'
  %           bool_rad=1;
  %           loc_rad=i;
  %           rad=varargin{loc_rad+1};
  %           i=i+1;
  %         case 'trainlen' 
  %           bool_trainlen=1;
  %           loc_trainlen=i;
  %           trlen=varargin{loc_trainlen+1};
  %           i=i+1;
  %         otherwise
  %           i=i+1;
  %       end
  %     else
  %       i=i+1;
  %     end
  %   end

284
  if bool_pre_training
285 286
    pretrain_tracking = tracking;
    %pretrain_tracking = 1;
287
    
288 289 290
    % batchtrain avec radius ...
    if (bool_rad && ~bool_trainlen)
      fprintf(1,'\n-- BATCHTRAIN initial avec radius ... ')
291
      if pretrain_tracking, fprintf(1,'\n'); end
292 293 294
      j=1;
      while j<length(rad)
        
295
        sMap=som_batchtrain(sMap, sD_norm, ...
296 297
                            'radius',[rad(j) rad(j+1)], ...
                            'tracking',pretrain_tracking);
298 299
        j=j+1;
        
300 301
      end
    end
302 303 304
    % batchtrain avec trainlen ...
    if (~bool_rad && bool_trainlen) 
      fprintf(1,'\n-- BATCHTRAIN initial avec trainlen ... ')
305
      if pretrain_tracking, fprintf(1,'\n'); end
306 307 308
      j=1;
      while j<=length(trlen)
        
309
        sMap=som_batchtrain(sMap, sD_norm, ...
310 311
                            'trainlen',trlen(j), ...
                            'tracking',pretrain_tracking);
312 313 314 315 316
        j=j+1;
        
      end
    end
    % batchtrain avec radius et trainlen             
317
    if (bool_rad && bool_trainlen)
318
      fprintf(1,'\n-- BATCHTRAIN initial avec radius et trainlen ... \n')
319
      if pretrain_tracking, fprintf(1,'\n'); end
320 321
      if length(rad)==length(trlen)+1
        
322 323 324
        j=1;
        while j<length(rad)
          
325
          sMap=som_batchtrain(sMap, sD_norm, ...
326 327 328
                              'radius',[rad(j) rad(j+1)], ...
                              'trainlen',trlen(j), ...
                              'tracking',pretrain_tracking);
329 330 331
          j=j+1;
          
        end
332 333
      else
        error('vecteur radius doit avoir un element en plus que le vecteur trainlen ')
334
      end
335
    end
336 337
    sMapPT = sMap;

338 339 340 341
    current_perf = som_distortion(sMap,sD_norm);
    fprintf(1,'--> som_distortion apres entrainement initiale = %s\n', num2str(current_perf));
    
  else
342
    fprintf(1,'** batchtrain initial (pre-training) non active **\n')
343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392
  end
  % end
  
  % %S2-SOM
  % bool_2ssom=0;
  % bool_DimData=0;
  % bool_lambda=0;
  % bool_eta=0;
  %
  % if ~isempty(varargin)
  %   i=1;
  %   while i<=length(varargin)
  %     if ischar(varargin{i}) 
  %       switch varargin{i} 
  %      
  %         case 'S2-SOM'
  %           disp('** S2-SOM Active **');
  %           bool_2ssom=1;
  %           i=i+1;
  %           %mettre en bloc
  %         case 'DimData'
  %           i=i+1;
  %           DimData=varargin{i};
  %           for di=1:length(DimData)
  %             DimBloc(di).Dim=DimData(di);
  %           end
  %           bool_DimData=1;
  %         case 'lambda' 
  %           i=i+1; 
  %           lambda=varargin{i};
  %           if length(lambda) < 1
  %             error('lambda est de longueur nulle !  Il doit y avoir au moins une valeur')
  %           end
  %           bool_lambda=1;
  %         case 'eta' 
  %           i=i+1; eta=varargin{i};
  %           if length(eta) < 1
  %             error('eta est de longueur nulle !  Il doit y avoir au moins une valeur')
  %           end
  %           bool_eta=1;
  %         otherwise
  %           i=i+1;
  %
  %       end
  %     else
  %       i=i+1;
  %     end
  %   end
  
  if (bool_2ssom)
393
    if (bool_lambda && bool_eta && bool_DimData)
394 395 396 397 398 399 400
      
      best_i   = 0;
      best_j   = 0;
      bestperf = inf;
      
      i_train = 1;
      n_train = length(lambda)*length(eta);
401
      
402 403 404 405 406 407 408
      if ~bool_rad_2s_som
        rad_2s_som =  [rad(round(length(rad)/2)) ...
                          rad((round(length(rad)/2))+1)];
      end
      if ~bool_trlen_2s_som
        trlen_2s_som = trlen(round(length(trlen)/2));
      end
409

410 411 412
      fprintf(1,[ '\n-- batchtrainRTOM loop for %d lambda and %d eta values:\n', ... 
                  '-- ------------------------------------------------------------------\n' ], ...
              length(lambda), length(eta));
413
      if tracking > 1,
414 415
        fprintf(1,'   ... trainlen_2s_som ... %s\n', num2str(trlen_2s_som))
        fprintf(1,'   ... radius_2s_som ..... [%s]\n', join(string(rad_2s_som),', '))
416
      end
417 418
      for i=1:length(lambda)
        for j=1:length(eta)
419
          fprintf(1,'-- batchtrainRTOM (%d/%d) with lambda=%s and eta=%s ... ',i_train, ...
420 421
                  n_train, num2str(lambda(i)),num2str(eta(j)));
          if tracking, fprintf(1,'\n'); end
422
          
423 424 425 426 427 428 429
          [Result(i,j).sMap Result(i,j).bmus Result(i,j).Alpha Result(i,j).Beta] = som_batchtrainRTOM( ...
              sMap, sD_norm, ...
              'TypeAlgo','2SSOM', ...
              'DimData',DimData, ...
              'DimBloc',DimBloc, ...
              'lambda', lambda(i), ...
              'eta',eta(j), ...
430 431
              'radius',rad_2s_som, ...
              'trainlen',trlen_2s_som, ...
432
              'tracking',tracking);
433
          
434 435 436 437 438 439 440 441 442 443 444 445 446 447
          current_perf = som_distortion(Result(i,j).sMap,sD_norm);
          fprintf(1,'   --> som_distortion=%s\n', num2str(current_perf));
          %  end
          %end
          % best_i=0;
          % best_j=0;
          % bestperf=inf;
          % for i=1:length(lambda)
          %   for j=1:length(eta)
          %         
          Result(i,j).Perf = current_perf;
          if Result(i,j).Perf < bestperf
            best_i = i;
            best_j = j;
448
          end
449 450
          
          i_train = i_train + 1;
451 452 453
        end
      end
      
454
      sMap = Result(best_i,best_j).sMap;
455 456
      
    else
457
      error('manque de parametre')
458
    end
459 460
  elseif (bool_lambda || bool_eta || bool_DimData)
    error('mentionnez si vous voulez S2-SOM')
461 462
  end
  
463
  % end
464
  
465 466 467 468 469 470 471 472 473 474 475
  clear St
  if (bool_2ssom)
    % si 2S-SOM alors best perf sMap
    [BestPerf,iBest] = min(cell2mat({Result.Perf}));
    
    sMap = Result(iBest).sMap;
  else
    % sinon, si pas 2S-SOM
    sMap = sMapPT;
  end
  
476 477
  % denormalisation de la Map
  if bool_norm
478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504
    sMap_dnrm = som_denormalize(sMap,sD_norm.comp_norm);
  else
    sMap_dnrm = sMap;
  end
  
  if bool_return_struct
    % Si retour STRUCT
    St.sMap     = sMap;
    if bool_norm
      St.sMap_denorm = sMap_dnrm;
    end
    
    if (bool_2ssom)
      St.bmus   = Result(iBest).bmus;
      St.Alpha  = Result(iBest).Alpha;
      St.Beta   = Result(iBest).Beta;
      
      St.sMapPT   = sMapPT;
      if bool_norm
        St.sMapPT_denorm = som_denormalize(sMapPT,sD_norm.comp_norm);
      end
      
      St.Result   = Result;
      St.iBest    = iBest;
    end
    
    StsMap = St; % variable de retour
505
  else
506 507 508 509 510
    % Sinon, alors retour variables  ... (sMap sMap_denorm Result)
    StsMap      = sMap;
    sMap_denorm = sMap_dnrm;
    Resultout   = Result;
    sMapPTout   = sMapPT;
511 512 513
  end
  
  return