function [ht, err, sv] = fixed_truncate_ltr(x, opts)
%FIXED_TRUNCATE_LTR Truncate full tensor to htensor, leafs-to-root.
%
%   This is a slightly modified version of TRUNCATE_LTR in the 
%       Hierarchical Tucker Toolbox 
%       Christine Tobler, ETH Zurich, and D. Kressner, EPF Lausanne
%   to perform truncation at user-specified ranks.
%   See TRUNCATE_LTR for the original documentation.
%
%   [Y, ERR, SV] = FIXED_TRUNCATE_LTR(X, OPTS) truncates a multidimensional
%   array X to an htensor Y, according to OPTS:
%     OPTS.H_RANKS (mandatory): explicit hierarchical ranks, indexed by the
%       node numbers of the dimension tree. At least NR_NODES entries are
%       required; the root entry H_RANKS(1) is not used (root rank is 1).
%     OPTS.SV (optional): 'svd' (default) or 'gramian'; selects how the
%       left singular vectors of each matricization are computed.
%     OPTS.TREE_TYPE, or OPTS.CHILDREN together with OPTS.DIM2IND
%       (optional): forwarded to the HTENSOR constructor to pick the tree.
%     OPTS.DISP_ERRTREE (optional): if true, display the node-wise
%       truncation errors and the resulting error bounds.
%
%   ERR(ii) is the truncation error at node ii, i.e. the norm of the
%   singular values discarded there (0 for the root). SV{ii} holds the
%   singular values computed at node ii (empty for the root).

% Geometric Hierarchical Tucker 
% Copyright 2012 Bart Vandereycken, EPF Lausanne
% GPLv3 License, see COPYING.txt for details.

if(nargin ~= 2)
  error('Requires exactly 2 arguments.')
end

if(~isfloat(x))
  error('First argument must be a real or complex multidimensional array.')
end

if ~isa(opts, 'struct') || ~isfield(opts, 'h_ranks')
  error('Second argument must be a MATLAB struct with the field h_ranks.');
end

% Default and validate the SVD method once, rather than in every loop pass.
if(~isfield(opts, 'sv'))
  opts.sv = 'svd';
end
if ~any(strcmp(opts.sv, {'svd', 'gramian'}))
  error('Invalid value of OPTS.SV.');
end

% Initialize htensor t
if(isfield(opts, 'tree_type'))
  ht = htensor(size(x), opts.tree_type);
elseif(isfield(opts, 'children') && isfield(opts, 'dim2ind'))
  ht = htensor(size(x), opts.children, opts.dim2ind);
else
  ht = htensor(size(x));
end

nr_nodes = ht.nr_nodes;

% Every non-root node ii reads opts.h_ranks(ii); fail early with a clear
% message instead of an index error in the middle of the sweep.
if numel(opts.h_ranks) < nr_nodes
  error('OPTS.H_RANKS must have at least %d entries (one per node).', ...
        nr_nodes);
end

% Per-node outputs: basis matrices U (leafs), transfer tensors B (inner
% nodes), singular values, chosen ranks and truncation errors.
U   = cell(1, nr_nodes);
B   = cell(1, nr_nodes);
sv  = cell(1, nr_nodes);
k   = zeros(1, nr_nodes);
err = zeros(1, nr_nodes);

% Working copy that is successively reduced; x itself stays untouched
% during the leaf sweep so every leaf is matricized from the full data.
x_ = full(x);

ht_is_leaf = ht.is_leaf;
ht_dims = ht.dims;

% Traverse leafs of the dimension tree. The (:).' makes the loop run over
% individual indices regardless of the orientation returned by find.
for ii = find(ht_is_leaf(:)).'
  mu = ht_dims{ii};

  % Matricization of x with mode mu as rows, all other modes as columns
  x_mat = matricize(x, mu, [1:mu-1, mu+1:ndims(x)], false);

  % Left singular vectors cut at the prescribed rank
  [U{ii}, sv{ii}, k(ii), err(ii)] = truncated_left_svd(x_mat, opts, ii);

  % Reduce tensor x_ along this dimension
  x_ = ttm(x_, U{ii}, mu, 'h');
end

% Set x to be the reduced tensor x_
x = x_;

ht_lvl = ht.lvl;

% Go through all levels from leafs to root node
for lvl_iter = max(ht_lvl):-1:0
  % Go through all nodes at given level
  for ii = find(ht_lvl(:) == lvl_iter).'

    % Leafs have already been treated
    if(ht_is_leaf(ii))
      continue;
    end

    % Matricization of x corresponding to node ii
    x_mat = matricize(x, ht_dims{ii});

    if(ii == 1)
      % Root node: the matricization is a vector and the rank is 1
      U_ = x_mat;
      k(ii) = 1;
    else
      [U_, sv{ii}, k(ii), err(ii)] = truncated_left_svd(x_mat, opts, ii);
    end

    % Child nodes' indices
    ii_left  = ht.children(ii, 1);
    ii_right = ht.children(ii, 2);

    % Reshape B{ii} from matrix U_ to a
    % k(ii_left) x k(ii_right) x k(ii) tensor
    B{ii} = dematricize(U_, [k(ii_left), k(ii_right), k(ii)], ...
                        [1 2], 3, false);

    % Reduce tensor x_ along dimensions dims{ii}: project the
    % matricization (dims{ii} as rows) onto the columns of U_
    x_mat_ = matricize(x_, ht_dims{ii});
    U_x_mat = U_'*x_mat_;

    % Instead of eliminating one of the dimensions, set it to be a
    % singleton, to keep the dimension order consistent
    tsize_red = size(x_); tsize_red(end+1:ndims(ht)) = 1;
    tsize_red(ht_dims{ii_left }(1)) = k(ii);
    tsize_red(ht_dims{ii_right}(1)) = 1;

    % Reshape the projected matrix back to tensor x_
    x_ = dematricize(U_x_mat, tsize_red, ht_dims{ii});
  end

  % Set x to be the reduced tensor x_ for the next (higher) level
  x = x_;
end

% Call htensor constructor
ht = htensor(ht.children, ht.dim2ind, U, B, true);

% Display the estimated errors
if(isfield(opts, 'disp_errtree') && opts.disp_errtree == true)
  disp_tree(ht, 'truncation_error', err);

  % We know from theory that
  %
  % ||X - X_best|| <= ||X - X_|| <= err_bd <= factor*||X - X_best||
  %
  % and max(err) <= ||X - X_best||, therefore
  %
  % max(err_bd/factor, max(err)) <= ||X - X_best|| <= ||X - X_|| <= err_bd
  %
  % give upper and lower bounds for the best approximation as well
  % as the truncated version constructed here.

  % Count top-level truncation only once
  err_ = err; err_(ht.children(1, 1)) = 0;

  % Calculate upper bound and factor c from ||x - x_|| <= c ||x - x_best||
  err_bd = norm(err_); factor = sqrt(2*ndims(ht)-3);

  fprintf(['\nLower/Upper bound for best approximation error:\n' ...
           '%e <= ||X - X_best|| <= ||X - X_|| <= %e\n'], ...
          max(err_bd/factor, max(err)), err_bd);
end

end


function [U_, s, k, err] = truncated_left_svd(x_mat, opts, ii)
%TRUNCATED_LEFT_SVD Left singular vectors of X_MAT cut at OPTS.H_RANKS(II).
%
%   Assumes OPTS.SV has already been validated to be 'svd' or 'gramian'.
%   Returns the first K columns of the left singular vectors, the singular
%   values S, the rank K and the truncation error ERR from FIXED_TRUNC_RANK.

if(strcmp(opts.sv, 'gramian'))
  [U_, s] = htensor.left_svd_gramian(x_mat*x_mat');
else
  [U_, s] = htensor.left_svd_qr(x_mat);
end

opts2 = opts;
opts2.expl_rank = opts.h_ranks(ii);
[k, err] = fixed_trunc_rank(s, opts2);

U_ = U_(:, 1:k);

end


function [k, err, success] = fixed_trunc_rank(s, opts)
%FIXED_TRUNC_RANK Return rank according to user-specified parameters.
%
%   This is a slightly modified version of TRUNC_RANK in the 
%       Hierarchical Tucker Toolbox 
%       Christine Tobler, ETH Zurich, and D. Kressner, EPF Lausanne
%   to perform truncation at user-specified ranks.
%   See TRUNC_RANK for the original documentation.
%
%   [K, ERR, SUCCESS] = FIXED_TRUNC_RANK(S, OPTS) returns the truncation
%   rank K = MIN(OPTS.EXPL_RANK, NUMEL(S)) for the singular values S (row
%   or column vector), the resulting truncation error
%   ERR = NORM(S(K+1:END)) (0 when K equals NUMEL(S)), and SUCCESS, which
%   is always true.
%
%   CHANGED: the rank is fixed by the scalar OPTS.EXPL_RANK, which must be
%   a nonnegative integer (Inf keeps all singular values).

if(nargin ~= 2)
  error('Requires 2 arguments.')
end

if(~isnumeric(s) || ~isvector(s))
  error('First argument must be a vector.');
end

if(~isa(opts, 'struct') || ~isfield(opts, 'expl_rank') )
  error('Second argument must be a MATLAB struct with the field expl_rank.');
end

r = opts.expl_rank;
if ~(isnumeric(r) && isscalar(r) && r >= 0 && r == fix(r))
  error('OPTS.EXPL_RANK must be a nonnegative integer scalar.');
end

% Work with a column so the zero sentinel below can always be appended,
% whether S was passed as a row or as a column.
s = s(:);

% tail(j) = norm(s(j:end)), accumulated from the end of S; the appended
% 0 is the error when nothing is discarded (k == numel(s)).
tail = sqrt(cumsum(s(end:-1:1).^2));
tail = [tail(end:-1:1); 0];

k = min(r, numel(s));
err = tail(k+1);
success = true;

end

