function [ts, err_dyn, err_svd] = do_calc_integrate(x, sz, hk, epsi, cte, Tend)
% DO_CALC_INTEGRATE  Integrate and calculate point-wise SVDs for a time-varying tensor.
%
%   [ts, err_dyn, err_svd] = do_calc_integrate(x, sz, hk, epsi, cte, Tend)
%
%   Builds the time-dependent full array
%       A(t) = X(t) + exp(cte*t)*epsi*(t + sin(3*t))*C,
%   where X(t) is the random HT tensor htenrandn(sz,'orthog',hk) with every
%   transfer tensor B multiplied by exp(t) (the frames U stay fixed), and C is
%   a random array of size sz scaled to norm(C(:)) == 1. Note A(0) = X(0).
%   It then
%     1) integrates dY/dt = project_onto_Tspace(Y, dA/dt) with ode45 from
%        Y(0) = X(0), working on the parameter vector of Y, and
%     2) truncates A(t) directly with fixed_truncate_ltr,
%   and measures both approximations against A(t) at the times ts.
%
%   Inputs:
%     x    - ignored: it is overwritten by a freshly drawn random htensor
%            (kept only for interface compatibility).
%     sz   - tensor size, passed to htenrandn and rand.
%     hk   - ranks passed to htenrandn.
%     epsi - amplitude of the perturbation term.
%     cte  - exponential rate of the perturbation term.
%     Tend - final integration time.
%
%   Outputs:
%     ts      - 1x100 vector, linspace(0, Tend, 100).
%     err_dyn - err_dyn(i) = norm of A(ts(i)) - Y(ts(i)) (vectorized 2-norm).
%     err_svd - err_svd(i) = norm of A(ts(i)) - fixed_truncate_ltr(A(ts(i))),
%               for i >= 2; err_svd(1) is not computed and is 0.
%
%   The initial error err_x0, timings and the ode45 statistics are printed.

% Geometric Hierarchical Tucker
% Copyright 2012 Bart Vandereycken, EPF Lausanne
% GPLv3 License, see COPYING.txt for details.

% NOTE: A, dA and dydt below are *nested* functions. Any variable name that
% also appears in this parent body is SHARED with them, so their locals use
% names that the parent never uses (x_A, x_dA, x_y, ...). Otherwise calling
% them would silently overwrite parent variables (e.g. the loop iterate).

% The input x is deliberately replaced by a fresh random HT tensor.
x = htenrandn(sz, 'orthog', hk);
C = rand(sz); C = C/norm(C(:));
nb_nodes = x.nr_nodes;

%%% Functions A(t) and dA(t)/dt
function At = A(t)
  % Full array A(t): the HT tensor with every B scaled by exp(t) and the
  % constant frames x.U, plus the time-scaled perturbation C.
  B_A = x.B;
  for k_A_ = 1:nb_nodes
    B_A{k_A_} = B_A{k_A_}*exp(t);
  end
  x_A = htensor(x.children, x.dim2ind, x.U, B_A);
  At = full(x_A) + exp(cte*t)*epsi*(t+sin(3*t))*C;
end

function [dAt, At] = dA(t)
  % Full array dA(t)/dt and, when a second output is requested, A(t).
  B_dA  = cell(1, nb_nodes);
  dB_dA = cell(1, nb_nodes);
  dU_dA = cell(1, nb_nodes);

  for k_dA_ = 1:nb_nodes
    B_dA{k_dA_}  = x.B{k_dA_}*exp(t);
    dB_dA{k_dA_} = x.B{k_dA_}*exp(t);   % d/dt (B*exp(t)) = B*exp(t)
    % A(t) uses the constant frames x.U, so their time derivative is zero.
    % 0*U keeps the size of every cell entry (including empty ones).
    dU_dA{k_dA_} = 0*x.U{k_dA_};
  end

  x_dA  = htensor(x.children, x.dim2ind, x.U, B_dA);
  dx_dA = tangent(x_dA, dU_dA, dB_dA);

  % Time derivative of the perturbation coefficient exp(cte*t)*epsi*(t+sin(3t)).
  dpert_dA = exp(cte*t)*epsi*(1+3*cos(3*t)) + cte*exp(cte*t)*epsi*(t+sin(3*t));

  if nargout == 2
    [dX_dA, X_dA] = full(dx_dA);
    dAt = dX_dA + dpert_dA*C;
    At  = X_dA + exp(cte*t)*epsi*(t+sin(3*t))*C;
  else
    dAt = full(dx_dA) + dpert_dA*C;
  end
end


%%% Standard matlab integration

x0 = htensor(x.children, x.dim2ind, x.U, x.B);
D = full(x0) - A(0);
err_x0 = norm(D(:))

[y0, sizes] = htensor_to_param(x0);

function dy = dydt(t, y)
  % ODE right-hand side in parameter space: project dA/dt at the HT tensor
  % encoded by y (via project_onto_Tspace) and return it as parameters.
  dA_y = dA(t);
  x_y  = param_to_htensor(y, sizes);
  dx_y = project_onto_Tspace(x_y, dA_y);
  dy   = tang_to_param(dx_y);
end

ts = linspace(0, Tend, 100);

opts = odeset('RelTol', 1e-3, 'AbsTol', 1e-6, 'NormControl', 'on');
T = tic;
sol = ode45(@dydt, [0 Tend], y0, opts);
TOTAL_TIME_ODE = toc(T)
SCALED_TIME_ODE = TOTAL_TIME_ODE/sol.stats.nfevals
sol.stats

%%% Verify errors

err_dyn = zeros(1, numel(ts));
for ii = 1:numel(ts)
  % x_dyn is not used by any nested function, so calling A below cannot
  % overwrite it.
  x_dyn = param_to_htensor(deval(sol, ts(ii)), sizes);
  D = A(ts(ii)) - full(x_dyn);
  err_dyn(ii) = norm(D(:));
end


%%% Calculate the errors for truncated SVD
opts2.h_ranks = get_h_ranks(x0);
opts2.max_rank = max(sz);
%opts2.abs_eps = 0; opts2.rel_eps = 0;

T = tic;
xts = cell(1, numel(ts));
for ii = 2:numel(ts)
  xts{ii} = fixed_truncate_ltr(A(ts(ii)), opts2);
end
TOTAL_TIME_SVD = toc(T)
SCALED_TIME_SVD = TOTAL_TIME_SVD/(numel(ts)-1)

% err_svd(1) is not computed (the loops start at 2) and stays 0.
err_svd = zeros(1, numel(ts));
for ii = 2:numel(ts)
  D = A(ts(ii)) - full(xts{ii});
  err_svd(ii) = norm(D(:));
end

end