
function [Xfin,Pfin,Ftot,VBSSM]=MSBSSM(signalTOT_PP,p,ap,decalage,tol,cyc)
%--------------------------------------------------------------------------
% MSBSSM  Iterative (E-step / M-step) estimation of the MSBSS model,
%         initialised with an EM fit (EM_ShumStoff_HAAR_sezen).
%--------------------------------------------------------------------------
% Function arguments:
%--------------------------------------------------------------------------
% signalTOT_PP: [chans x len x N] preprocessed signal
%               (channels x time samples x trials)
% p:            Vector of model orders, one entry per scale; leading zero
%               entries are skipped and the last entry is used as pmax
% ap:           Prior value for the transition matrix A      (default 0.9)
% decalage:     Lag between y_t and y_{t-1}                   (default 1)
% tol:          Convergence criterion of the iterations       (default 1e-2)
% cyc:          Maximum number of iterations                  (default 10)
%--------------------------------------------------------------------------
% Outputs:
%--------------------------------------------------------------------------
% Xfin:         Time-varying VAR coefficients (see Cekic et al. (2019))
% Pfin:         Var_cov matrix of the time-varying VAR coefficients
% Ftot:         Free energy value of the last completed iteration
% VBSSM:        Structure with all the Bayesian estimated parameters
%               (see Cekic et al. (2019) for more details):
%                 q_A   = {A, Ac},  q_alf = {alf, ca, ba},
%                 q_Q   = {Q, q, D}, q_R  = {R, r, B},
%                 KLR, KLQ, KLA, KLalf, KLar, KLaq, KLalfalf, Z,
%                 pfinal (orders actually used), tol, decalage
%--------------------------------------------------------------------------

%--------------------------------------------------------------------------
% Default values for the optional arguments
%--------------------------------------------------------------------------
if nargin < 6, cyc = 10;       end
if nargin < 5, tol = 10^(-2);  end
if nargin < 4, decalage = 1;   end
if nargin < 3, ap  = 0.9;      end

% Without at least one iteration none of the posterior quantities stored in
% VBSSM exist, so fail early with an explicit message.
if isempty(cyc) || cyc(1) < 1
    error('MSBSSM:invalidCyc','cyc must be at least 1.');
end

%--------------------------------------------------------------------------
% Dimensions
%--------------------------------------------------------------------------
scaleSTART=find(p ~= 0, 1);          % first scale with a non-zero order
scalemax = length(p)-1;
chans=size(signalTOT_PP,1);
len=size(signalTOT_PP,2);
N=size(signalTOT_PP,3);
k=(sum(p))*chans^2;                  % dimension of the state vector
pmax=p(scalemax+1);
p=p(scaleSTART:scalemax+1);          % drop the leading zero-order scales

T=len-(pmax*2^(scalemax)+decalage);
fprintf(' MSBSS model estimation\n');

%--------------------------------------------------------------------------
% INITIALISATION: EM Shumway and Stoffer, 1982
%--------------------------------------------------------------------------
[VBSSM_EM]=EM_ShumStoff_HAAR_sezen(signalTOT_PP,1,p,decalage,scaleSTART,scalemax,ap);
x0=VBSSM_EM.x0;
P0=VBSSM_EM.P0;
Q=VBSSM_EM.Q;
iQ=inv(Q);
A=VBSSM_EM.A;
R=VBSSM_EM.R;

%------------------------
% R
%------------------------
% ar: prior
AR_R=10^(5);
vr=2; % Hyperparameters
apr=0.5*ones(1,chans);
bpr=AR_R.^(-2)*ones(1,chans);
iar=bpr; % starting value

% R: prior
rp=vr+chans-1;
Bpp=2*vr*diag(iar); % inv_Wishart
Bp=repmat(Bpp,[1 1 N]);              % same prior scale matrix for every trial

%------------------------
% Q
%------------------------
% aq: prior
AR_Q=10^(5);
apq=0.5*ones(1,k);
bpq=AR_Q^(-3)*ones(1,k);
iaq=bpq; % starting value

% Q: prior
qp=0.5;
Dp=diag(iaq);

%------------------------
% alpha
%------------------------
% Prior
Aalf=10^(5);
cp=1/2;
bp=Aalf^(-2);
idelt=10^(-1).*ones(1,k); % starting value

cap=1/2*ones(1,k);
bap=idelt;

%--------------
% A: prior
%--------------
% From here on ap is a k x k diagonal matrix holding the prior value.
ap=eye(k)*ap;

%----------------------------------------------------------
% STARTING VALUES
%----------------------------------------------------------
% Ac: variance
alf=ones(1,k)*10^(0); % starting value
Ac=eye(k)*10^(0);

%----------------------------------------------------------
% matrix initialisation
%----------------------------------------------------------
Ftot=0;
Xfin=zeros(k,T);
[X,Y] = matrices_construction(signalTOT_PP);

% Haar decomposition of the history (lagged-signal) matrix
for i=1:chans*N
    HAAR(:,:,i)=atrouwhaar(X(i,:),scalemax);
end

%--------------------------------------------------------------------------
% MAIN LOOP
%--------------------------------------------------------------------------
for cycle=1:cyc

    oldf=Ftot;        % previous free energy (only used by the disabled
                      % F-based convergence test further down)
    Xfinpre=Xfin;     % previous state estimate, for the convergence test

    %----------------------------------------------------------------------
    % E-STEP
    %----------------------------------------------------------------------
    % SIMPLE RTS (disabled alternative)
    %  [Xfin,Pfin,Pcov,Xpre,Ppre]=vksm_sezen_HAAR(Y,X,x0,P0,A,Q,R,p,pmax,decalage,T,scaleSTART,scalemax,k,HAAR,N);

    % UNIFIED INFERENCE SMOOTHER
    UU=abs(Ac*iQ);
    UA=chol(UU);
    [Xfin,Pfin,Pcov,Xpre,Ppre]=vksm2HAAR_sezen_lag(Y,X,x0,P0,A,Q,R,p,pmax,decalage,T,scaleSTART,scalemax,k,HAAR,UA,N);

    % MEAN E-STEP ON ALL TRIALS
    % For trial j the rows (i-1)*N+j, i=1..chans, of y and C are selected.
    % NOTE(review): cmat_sezen_HAAR is called with arguments that do not
    % depend on j, so it could be hoisted out of the trial loop if it has
    % no side effects - confirm before changing.
    for j=1:N
        rowsj=(0:chans-1)*N+j;
        YXX=0;
        for t=1:T
            [y,C]=cmat_sezen_HAAR(Y,X,t,p,pmax,decalage,scaleSTART,scalemax,HAAR,N);
            yj=y(rowsj,:);
            Cj=C(rowsj,:);
            res=yj-Cj*Xfin(:,t);
            YXX=YXX+res*res'+Cj*Pfin(:,:,t)*Cj';
        end
        YX(:,:,j)=YXX;
    end

    % Second-moment sums of the smoothed state
    % A3: sum over t=2..T   of P_t + x_t x_t'
    A3=0;
    for t=2:T
        A3=A3+(Pfin(:,:,t)+Xfin(:,t)*Xfin(:,t)');
    end
    % A2: sum over t=1..T-1 of P_t + x_t x_t'
    A2=0;
    for t=1:T-1
        A2=A2+(Pfin(:,:,t)+Xfin(:,t)*Xfin(:,t)');
    end
    % A1: sum over t=2..T   of Pcov_t + x_t x_{t-1}'
    A1=0;
    for t=2:T
        A1=A1+(Pcov(:,:,t)+Xfin(:,t)*Xfin(:,t-1)');
    end

    %----------------------------------------------------------------------
    % M-STEP
    %----------------------------------------------------------------------
    % A and alf
    [A,Ac,alf,ba,bap,ca,bq]=statA_alf_delt(iQ,A1,A2,ap,alf,cap,bap,bp,chans,idelt);

    % Q
    [iQ,Q,D,q,iaq,Dp,aqq,bqq]=statQ2(A,Ac,A1,A2,A3,qp,T,iaq,bpq,Dp);

    % R
    % [iR,R,B,r,iar,aqr,bqr,Bp]=statR2(R,YX,rp,T,iar,vr,bpr,Bp,N);
    [R,B,r,iar,aqr,bqr,Bp]=statR2FULL(R,YX,rp,T,iar,vr,bpr,Bp,N);

    %----------------------------------------------------------------------
    % COMPUTE FREE ENERGY
    %----------------------------------------------------------------------
    % ALFA
    % iba = theta (Wikipedia)
    % kl_alf=KL_gammaMOI(ca,iba,cap,ibap,1);
    kl_alf=KL_gammaMOI(ca,ba,cap,bap,k);

    % kl_aq=KL_gammaMOI(aqq(1),bqq(1),apq(1),bpq(1),1);
    kl_aq=KL_gammaMOI(aqq,bqq,apq,bpq,k);

    kl_alfalf=KL_gammaMOI(1,bq,cp,bp,1);

    % A: one divergence per diagonal coefficient, summed over ALL of them.
    % ap is a k x k diagonal matrix, so the prior mean of coefficient i is
    % ap(i,i); linear indexing ap(i) would read the first column instead.
    klAterms=zeros(1,k);
    for i=1:k
        klAterms(i)=DIVERGENCE_A(A(i,i),ap(i,i),Ac(i,i),1/alf(i),p);
    end
    kl_A=sum(klAterms);

    % Q
    % [kl_Q]=DIVERGENCE_Q(qp,Dp,q,D,modeQ,p,k);
    % NOTE(review): the posterior scale is passed as the diagonal MATRIX
    % diag(diag(D)) while the prior scale is the VECTOR diag(Dp) - confirm
    % this is what KL_gammaMOI expects.
    kl_Q=KL_gammaMOI(q/2,diag(diag(D))/2,qp/2,diag(Dp)/2,k);

    % R and ar (Beal free-energy calculus)
    % [kl_R]=DIVERGENCE_R(rp,Bp,r,B)
    % NOTE(review): the last argument is hard-coded to 2 whereas kl_ar
    % below uses chans - confirm this is intended when chans ~= 2.
    for j=1:N
        kl_R1(j)=KL_gammaMOI(r/2,(diag(B(:,:,j)))/2,rp/2,(diag(Bp(:,:,j)))/2,2);
    end
    kl_R=sum(kl_R1);

    kl_ar=KL_gammaMOI(aqr,bqr,apr,bpr,chans);

    [ZTOT]=Z_BEAL_HAAR(signalTOT_PP,T,p,pmax,decalage,Xpre,Ppre,R,scaleSTART,scalemax,N,chans);
    [Ftot]=F_BEAL_SEZEN2(ZTOT,kl_R,kl_Q,kl_A,kl_alf,kl_aq,kl_ar,kl_alfalf);
    Ffin(cycle)=Ftot;   % free-energy history (kept locally, not returned)

    %----------------------------------------------------------------------
    % CONTROL FOR THE POSITIVENESS OF DIVERGENCES
    %----------------------------------------------------------------------
    if (kl_Q < 0)
        fprintf(' VIOLATION DIVERGENCE Q!!!!!!!!\n');
    end
    if (kl_R < 0)
        fprintf(' VIOLATION DIVERGENCE R!!!!!!!!\n');
    end
    if (kl_A < 0)
        fprintf(' VIOLATION DIVERGENCE A!!!!!!!!\n');
    end
    if (kl_alf < 0)
        fprintf(' VIOLATION DIVERGENCE alf!!!!!!!!\n');
    end

    %----------------------------------------------------------------------
    % CONTROL FOR CONVERGENCE
    %----------------------------------------------------------------------
    % BASED ON F (disabled alternative)
    % if (cycle<=2)
    %     fbase=Ftot;
    % elseif (Ftot<oldf)
    %     fprintf(' violation!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!\n');
    %     CONV=0;
    %     break;
    %  elseif ((Ftot-oldf) < tol)
    %  elseif ((Ftot-fbase)<(1+tol)*(oldf-fbase)|~isfinite(Ftot))
    %     CONV=1;
    %     break;
    % end
    %
    % diff=Ftot-oldf;
    %
    % fprintf('Ftot-oldF %d \n',diff);

    % BASED ON Xfin
    % Total squared change of the state estimate between two iterations.
    % (A scale-relative alternative, DIFF < tol*T*k*sum(var(Xfin,0,2)),
    % is currently not used.)
    stepSq=zeros(1,T);
    for t=1:T
        dx=Xfinpre(:,t)-Xfin(:,t);
        stepSq(t)=dx'*dx;
    end
    DIFF=sum(stepSq);

    % The first two iterations are never allowed to stop the loop.
    if (cycle>2) && (DIFF<tol)
        break;
    end

    %----------------------------------------------------------------------
    % UPDATE HYPERPARAMETERS
    %----------------------------------------------------------------------
    % x0 P0 (Ostwald) - disabled alternative
    % [y,C]=cmat_sezen_HAAR(X,1,p,pmax,decalage,scale,scaleSTART,scalemax,HAAR);
    % x0 = Xfin(:,1) - Pfin(:,:,1)*C'*inv(C*Pfin(:,:,1)*C' + R)*(y - C*Xfin(:,1));
    % P0= Pfin(:,:,1) + Pfin(:,:,1)*C'*inv(C*Pfin(:,:,1)*C' + R)*C*Pfin(:,:,1);

    % Perr=(Xfin(:,1)-x0)*(Xfin(:,1)-x0)';
    x0 = Xfin(:,1);
    P0= Pfin(:,:,1);%+Perr;

    fprintf('cycl= %d \n',cycle);

end

%----------------------------
% VBSSM STRUCTURE BUILDING
%----------------------------
VBSSM.q_A{1}=A; VBSSM.q_A{2}=Ac;
VBSSM.q_alf{1}=alf; VBSSM.q_alf{2}=ca; VBSSM.q_alf{3}=ba;
VBSSM.q_Q{1}=Q; VBSSM.q_Q{2}=q; VBSSM.q_Q{3}=D;
VBSSM.q_R{1}=R; VBSSM.q_R{2}=r; VBSSM.q_R{3}=B;
VBSSM.KLR=kl_R;
VBSSM.KLQ=kl_Q;
VBSSM.KLA=kl_A;
VBSSM.KLalf=kl_alf;
VBSSM.Z=ZTOT;
VBSSM.KLar=kl_ar;
VBSSM.KLaq=kl_aq;
VBSSM.KLalfalf=kl_alfalf;
VBSSM.pfinal=p;
VBSSM.tol=tol;
VBSSM.decalage=decalage;

end

         