classdef classregtree
%CLASSREGTREE Create a classification and regression tree object.
%   T = CLASSREGTREE(X,Y) creates a decision tree T for predicting response Y as
%   a function of predictors X.  X is an N-by-M matrix of predictor values.
%   If Y is a vector of N response values, then CLASSREGTREE performs
%   regression.  If Y is a categorical variable, character array, or cell
%   array of strings, CLASSREGTREE performs classification.  Either way, T is
%   a binary tree where each non-terminal node is split based on the values
%   of a column of X.  NaN values in X or Y are taken to be missing values,
%   and observations with any missing values are not used in the fit.
%
%   T = CLASSREGTREE(X,Y,'PARAM1',val1,'PARAM2',val2,...) specifies optional
%   parameter name/value pairs:
%
%   For all trees:
%      'categorical' Vector of indices of the columns of X that are to be
%                   treated as unordered categorical variables
%      'method'     Either 'classification' (default if Y is text or a
%                   categorical variable) or 'regression' (default if Y is
%                   numeric)
%      'names'      A cell array of names for the predictor variables,
%                   in the order in which they appear in the X matrix
%                   from which the tree was created (see TREEFIT)
%      'prune'      'on' (default) to compute the full tree and the optimal
%                   sequence of pruned subtrees, or 'off' for the full tree
%                   without pruning
%      'splitmin'   A number K such that impure nodes must have K or more
%                   observations to be split (default 10)
%
%   For classification trees only:
%      'cost'       Square matrix C, where C(i,j) is the cost of classifying
%                   a point into class j if its true class is i (default
%                   has C(i,j)=1 if i~=j, and C(i,j)=0 if i=j).  Alternatively,
%                   this value can be a structure S having two fields:  S.group
%                   containing the group names as a categorical variable,
%                   character array, or cell array of strings; and S.cost
%                   containing the cost matrix C.
%      'splitcriterion'  Criterion for choosing a split. One of 'gdi' (default)
%                   for Gini's diversity index, 'twoing' for the twoing rule,
%                   or 'deviance' for maximum deviance reduction.
%      'priorprob'  Prior probabilities for each class, specified as a
%                   vector (one value for each distinct group name) or as a
%                   structure S with two fields:  S.group containing the group
%                   names as a categorical variable, character array, or cell
%                   array of strings; and S.prob containing a vector of
%                   corresponding probabilities.
%
%   Example:  Create classification tree for Fisher's iris data.
%      load fisheriris;
%      t = classregtree(meas, species,'names',{'SL' 'SW' 'PL' 'PW'});
%      view(t);
%
%   See also CLASSREGTREE/EVAL, CLASSREGTREE/TEST, CLASSREGTREE/VIEW, CLASSREGTREE/PRUNE.

%   Copyright 2006-2007 The MathWorks, Inc. 
%   $Revision: 1.1.6.6 $  $Date: 2007/06/14 05:26:04 $

    % Internal representation of the fitted tree.  Node-indexed vectors are
    % all topnode-by-1 (or topnode-by-2 for children); the [C]-marked
    % per-class fields are populated for classification trees only.
    properties(GetAccess='private', SetAccess='private')
        method = 'regression';
          node = zeros(0,1);
        parent = zeros(0,1);
         class = zeros(0,1);
           var = zeros(0,1);
           cut = zeros(0,1);
      children = zeros(0,2);
      nodeprob = zeros(0,1);
       nodeerr = zeros(0,1);
          risk = zeros(0,1);
      nodesize = zeros(0,1);
         npred = 0;
       catcols = [];
         prior = [];
      nclasses = 1;
          cost = [];
     classprob = [];
    classcount = [];
     classname = {};
      catsplit = cell(0,2);
     prunelist = zeros(0,1);
         alpha = [];
    ntermnodes = [];
         names = {};

    end
    
    methods
    function a = classregtree(x,y,varargin)
        % Two construction paths:
        %   classregtree(X,Y,...)  -> fit a new tree via the local TREEFIT
        %   classregtree(S)        -> wrap an existing tree structure S
        if ~(nargin==1 && isa(x,'struct'))
            error(nargchk(2,Inf,nargin,'struct'));
            a = treefit(x,y,varargin{:});
        else
            t = x;
            % Look at all fields required for regression or classification trees
            allfields = {'method'   'node'     'parent'   'class'   'var' ...
                         'cut'      'children' 'nodeprob' 'nodeerr' 'risk' ...
                         'nodesize' 'npred'    'catcols'  ...
                         'nclasses' 'prior'    'cost'     'classprob' 'classcount' 'classname'};
             fn = fieldnames(t);
             % Regression trees need only the first 13 fields; classification
             % trees need all 19 (the class-related ones too).
             if ~ismember('method',fn) || ...
                (isequal(t.method,'classification') && ~all(ismember(allfields,fn))) || ...
                (isequal(t.method,'regression')     && ~all(ismember(allfields(1:13),fn)))
                error('stats:classregtree:BadTree',...
                      'Input structure argument is not a valid decision tree.');
             end
             if isequal(t.method,'regression')
                 nrequired = 13;
             else
                 nrequired = numel(allfields);
             end
             for j=1:nrequired
                 fname = allfields{j};
                 a.(fname) = t.(fname);
             end
             
             % Look at optional fields
             optionalfields = {'names' 'prunelist' 'alpha' 'ntermnodes' 'catsplit'};
             for j=1:numel(optionalfields)
                 fname = optionalfields{j};
                 if isfield(t,fname)
                     a.(fname) = t.(fname);
                 end
             end
        end
        
        end % classregtree constructor
    end % methods block

    % NOTE(review): later MATLAB releases spell this attribute 'Hidden';
    % confirm 'Visible' is accepted by the targeted release before changing.
    methods(Visible = false)        
        % Methods that we inherit from opaque, but do not want
        function varargout = cat(varargin),     throwNoCatError; end
        function varargout = horzcat(varargin), throwNoCatError; end
        function varargout = vertcat(varargin), throwNoCatError; end

        % Methods that we support but don't provide m-file help
        function a = fields(varargin),          a = fieldnames(varargin{:}); end
        
        function a = fieldnames(varargin)
            % Fixed list of publicly visible field names for struct-style access.
            a = {'method'; 'node'; 'parent'; 'class'; 'var'; 'cut'; 'children'; ...
                 'nodeprob'; 'nodeerr'; 'risk'; 'nodesize'; 'npred'; 'catcols'; ...
                 'prior'; 'nclasses'; 'cost'; 'classprob'; 'classcount'; ...
                 'classname'; 'catsplit'; 'prunelist'; 'alpha'; 'ntermnodes'};
        end
        
    end % invisible methods block
end % classdef

function throwNoCatError
%THROWNOCATERROR Raise the standard error for concatenation of tree objects.
%   Shared helper for the disabled CAT/HORZCAT/VERTCAT methods.
%   (The previous version computed the calling method's name from DBSTACK
%   but never used it, and indexing st(2) failed when called at stack
%   depth 1; that dead code has been removed.)
error('stats:classregtree:NoCatAllowed', ...
      'Concatenation of CLASSREGTREE objects not allowed.\nUse a cell array to contain multiple objects.');
end

% ---------------------------------
function Tree=treefit(X,y,varargin)
%TREEFIT Fit a classification or regression decision tree (local helper).
%   TREE = TREEFIT(X,Y,'PARAM1',val1,...) grows a binary decision tree for
%   predicting Y from the N-by-NVARS numeric matrix X and returns it as a
%   structure (field list documented below).  Y numeric => regression;
%   Y text/categorical => classification.  See the CLASSREGTREE class help
%   for the accepted parameter name/value pairs ('priorprob', 'cost',
%   'splitcriterion', 'splitmin', 'categorical', 'prune', 'method', 'names').
%   Rows of X (and Y, for regression) containing NaN are dropped before
%   fitting.

% Process inputs
if isnumeric(y)
   Method = 'regression';
else
   Method = 'classification';
   if (ischar(y))
      y = cellstr(y);
   end
end
if ~isnumeric(X)
   error('stats:treefit:BadData','X must be a numeric matrix.');
end
okargs =   {'priorprob'   'cost'  'splitcriterion' 'splitmin'  ...
            'categorical' 'prune' 'method'         'names'};
defaults = {[]            []      'gdi'            10          ...
            []            'on'    Method           {}};
% NOTE(review): dfswitchyard/statgetargs is a toolbox-internal argument
% parser; 'extra' collects unrecognized name/value pairs for the
% backwards-compatibility scan below.
[eid,emsg,Prior,Cost,Criterion,Splitmin,categ,Prune,Method,names,extra] = ...
               dfswitchyard('statgetargs',okargs,defaults,varargin{:});
if ~isempty(emsg)
   error(sprintf('stats:treefit:%s',eid),emsg);
end

% For backwards compatibility. 'catidx' is a synonym for 'categorical'
for j=1:2:length(extra)
    if strncmp(extra{j},'catidx',length(extra{j}))
        categ = extra{j+1};
    else
        error('stats:treefit:BadParamName',...
              'Invalid parameter name:  %s.',extra{j});
    end
end

% Normalize 'method' to the full word; only the first letter is examined.
if ~ischar(Method) || isempty(Method) || ~(Method(1)=='c' || Method(1)=='r')
   error('stats:treefit:BadMethod',...
         'Value of ''method'' parameter must be ''classification'' or ''regression''.');
elseif Method(1)=='c'
   Method = 'classification';
else
   Method = 'regression';
end

% Drop observations with missing predictor values (and missing responses
% for regression; classification NaN groups are removed after grp2idx).
t = any(isnan(X),2);
if isequal(Method,'regression')
   t = t | isnan(y);
end
if any(t)
   X(t,:) = [];
   y(t) = [];
end

[N,nvars] = size(X);
doclass = isequal(Method(1),'c');
if ~isempty(names)
    if ischar(names)
        names = cellstr(names);
    end
    if ~iscellstr(names) || numel(names)~=nvars
        error('stats:treefit:BadNames',...
              'NAMES must be a character array or cell array with %d strings.',nvars);
    end
else
    % Default names x1, x2, ..., x<nvars>
    names = strcat('x',strread(sprintf('%d\n',1:nvars),'%s\n'));
end
if doclass
   switch(Criterion)
    %                Criterion function   Is it an impurity measure?
    %                ------------------   --------------------------
    case 'gdi',      critfun = @gdi;      impurity = 1;
    case 'twoing',   critfun = @twoing;   impurity = 0;
    case 'deviance', critfun = @deviance; impurity = 1;
    otherwise,     error('stats:treefit:BadSplitCriterion',...
                         'Bad value for ''splitcriterion'' parameter.')
   end
   
   % Get binary matrix, C(i,j)==1 means point i is in class j
   if islogical(y)
      y = double(y);
   end
   if numel(y)~=N
      error('stats:treefit:InputSizeMismatch','Y does not have %d elements.',N);
   end
   [y,cnames] = grp2idx(y);   % find groups only after NaNs removed from X
   if any(isnan(y))
      t = isnan(y);
      y(t) = [];
      X(t,:) = [];
      N = size(X,1);
   end
   nclasses = max(y);
   C = false(N,nclasses);
   C(sub2ind([N nclasses],(1:N)',y)) = 1;
   Nj = sum(C,1);
else
   C = y(:);
end

% Tree structure fields ([C] only for classification trees):
%  .method     method
%  .node       node number
%  .parent     parent node number
%  .class      class assignment for points in this node if treated as a leaf
%  .var        column j of X matrix to be split, or 0 for a leaf node,
%              or -j to treat column j as categorical
%  .cut        cutoff value for split (Xj<cutoff goes to left child node),
%              or index into catsplit if var is negative
%  .children   matrix of child nodes (2 cols, 1st is left child)
%  .nodeprob   probability p(t) for this node
%  .nodeerr    resubstitution error estimate r(t) for this node
%  .risk       R(t) = p(t)*r(t)
%  .nodesize   number of points at this node
%  .prunelist  list of indices that define pruned subtrees.  One entry per
%              node.  If prunelist(j)=k then, at the kth level of pruning,
%              the jth node becomes a leaf (or drops off the tree if its
%              parent also gets pruned).
%  .alpha      vector of complexity parameters for each pruning cut
%  .ntermnodes vector of terminal node counts for each pruning cut
%  .catsplit   call array for categorical splits,
%              left categories in column 1 and right categories in column 2
%  .classprob  [C] vector of class probabilities
%  .classname  [C] names of each class
%  .classcount [C] count of members of each class
%  .nclasses   [C] number of classes
%  .cost       [C] misclassification cost

% Pre-allocate per-node arrays at the maximum possible size (N nodes);
% they are truncated to topnode rows when the tree is assembled below.
nodenumber = zeros(N,1);
parent = zeros(N,1);
yfitnode = zeros(N,1);
cutvar = zeros(N,1);
cutpoint = zeros(N,1);
children = zeros(N,2);
nodeprob = zeros(N,1);
resuberr = zeros(N,1);
risk = zeros(N,1);
nodesize = zeros(N,1);
if doclass
   classprob = zeros(N,nclasses);
   classcount = zeros(N,nclasses);
end
catsplit = cell(0,2);
iscat = zeros(nvars,1); iscat(categ) = 1;

nodenumber(1) = 1;

% assignednode(i) is the node currently holding observation i; everything
% starts in the root (node 1).
assignednode = ones(N,1);
nextunusednode = 2;

if doclass
   % Get default or specified prior class probabilities
   Prior = Prior(:)';
   haveprior = true;
   if isempty(Prior)
      Prior = Nj / N;          % empirical class frequencies
      haveprior = false;
   elseif isequal(Prior,'equal')
      Prior = ones(1,nclasses) / nclasses;

   elseif isstruct(Prior)
      if ~isfield(Prior,'group') || ~isfield(Prior,'prob')
         error('stats:treefit:BadPrior',...
              'Missing field in structure value for ''priorprob'' parameter.');
      end
      idx = getclassindex(cnames,Prior.group);
      if any(idx==0)
         j = find(idx==0);
         error('stats:treefit:BadPrior',...
               'Missing prior probability for group ''%s''.',cnames{j(1)});
      end
      Prior = Prior.prob(idx);
   end
   if length(Prior)~=nclasses || any(Prior<0) || sum(Prior)==0 ...
                              || ~isnumeric(Prior)
      error('stats:treefit:BadPrior',...
            'Value of ''priorprob'' parameter must be a vector of %d probabilities.',...
            nclasses);
   else
      Prior = Prior / sum(Prior);   % normalize to sum to 1
   end

   % Get default or specified misclassification costs
   havecosts = true;
   if isempty(Cost)
      Cost = ones(nclasses) - eye(nclasses);  % 0/1 cost by default
      havecosts = false;
   else
      if isstruct(Cost)
         if ~isfield(Cost,'group') || ~isfield(Cost,'cost')
            error('stats:treefit:BadCost',...
                  'Missing field in structure value for ''cost'' parameter.');
         end
         idx = getclassindex(cnames,Cost.group);
         if any(idx==0)
            j = find(idx==0);
            error('stats:treefit:BadCost',...
                  'Missing misclassification cost for group ''%s''.',...
                          cnames{j(1)});
         end
         Cost = Cost.cost(idx,idx);
      end
      if ~isequal(size(Cost),nclasses*ones(1,2))
         error('stats:treefit:BadCost',...
               'Misclassification cost matrix must be %d-by-%d.',...
                       nclasses,nclasses);
      elseif any(diag(Cost)~=0)
         error('stats:treefit:BadCost',...
            'Misclassification cost matrix must have zeros on the diagonal.');
      elseif any(Cost<0)
         error('stats:treefit:BadCost',...
            'Misclassification cost matrix must contain non-negative values.');
      end
   end
   
   % Adjust priors if required to take misclassification costs into account
   adjprior = Prior;
   if havecosts
      Cj = sum(Cost,2)';
      pc = Cj .* Prior;
      adjprior = pc / sum(pc);
   end
end

% Keep processing nodes until done.  The loop appends new child nodes to
% the work list (nextunusednode) as splits are made, so it visits the tree
% breadth-first until no node remains to be examined.
tnode = 1;
while(tnode < nextunusednode)
   % Record information about this node
   noderows = find(assignednode==tnode);
   Nnode = length(noderows);
   Cnode = C(noderows,:);
   if doclass
      % Compute class probabilities and related statistics for this node
      Njt = sum(Cnode,1);    % number in class j at node t
      Pjandt = Prior .* Njt ./ Nj;
      Pjgivent = Pjandt / sum(Pjandt);
      misclasscost = Pjgivent * Cost;
      [mincost,nodeclass] = min(misclasscost);
      yfitnode(tnode) = nodeclass;
      Pt = sum(Pjandt);
      nodeprob(tnode) = Pt;
      classprob(tnode,:) = Pjgivent;
      classcount(tnode,:) = Njt;
      pratio = adjprior ./ Nj;
      impure = sum(Pjgivent>0)>1;
   else
      % Compute variance and related statistics for this node
      ybar = mean(Cnode);
      yfitnode(tnode) = ybar;
      nodeprob(tnode) = Nnode/N;
      sst = norm(Cnode-ybar)^2;   % total sum of squares at this node
      mincost = sst / Nnode;
      impure = (mincost > 1e-6 * resuberr(1));
   end
   bestcrit          = -Inf;
   nodesize(tnode)   = Nnode;
   resuberr(tnode)   = mincost;
   risk(tnode)       = nodeprob(tnode) * resuberr(tnode);
   cutvar(tnode)     = 0;
   cutpoint(tnode)   = 0;
   children(tnode,:) = 0;
   
   % Consider splitting this node
   if (Nnode>=Splitmin) && impure      % split only large impure nodes
      Xnode = X(noderows,:);
      bestvar = 0;
      bestcut = 0;

      % Find the best of all possible splits
      for jvar=1:nvars
         [x,idx] = sort(Xnode(:,jvar));          % get sorted jth x variable
         
         % Determine if there's anything to split along this variable
         maxeps = max(eps(x(1)), eps(x(end)));
         if x(1)+maxeps > x(end)
            continue;
         end
         rows = find(x(1:end-1)+maxeps < x(2:end));
         if isempty(rows)
            continue;
         end

         xcat = iscat(jvar);
         if doclass
            Ccum = cumsum(Cnode(idx,:));         % cum. class counts
            [critval,cutval]=Ccritval(x,Ccum,rows,xcat,pratio,Pt,impurity,critfun,bestcrit);
         else
            ycum = cumsum(Cnode(idx,:) - ybar);  % centered response cum. sum
            [critval,cutval]=Rcritval(x,ycum,rows,xcat);
         end

         % Change best split if this one is best so far
         if critval>bestcrit
            bestcrit = critval;
            bestvar = jvar;
            bestcut = cutval;
         end
      end

      % Split this node using the best rule found
      if bestvar~=0
         x = Xnode(:,bestvar);
         if ~iscat(bestvar)
            cutvar(tnode) = bestvar;
            cutpoint(tnode) = bestcut;
            leftside = x<=bestcut;
            rightside = ~leftside;
         else
            cutvar(tnode) = -bestvar;          % negative indicates cat. var. split
            ncatsplit = size(catsplit,1) + 1;  % index into catsplit cell array
            cutpoint(tnode) = ncatsplit;
            catsplit(ncatsplit,:) = bestcut;
            leftside = ismember(x,bestcut{1});
            rightside = ismember(x,bestcut{2});
         end
         children(tnode,:) = nextunusednode + (0:1);
         assignednode(noderows(leftside)) = nextunusednode;
         assignednode(noderows(rightside)) = nextunusednode+1;
         nodenumber(nextunusednode+(0:1)) = nextunusednode+(0:1)';
         parent(nextunusednode+(0:1)) = tnode;
         nextunusednode = nextunusednode+2;
      end
   end
   tnode = tnode + 1;
end

% Assemble the output structure, truncating work arrays to the nodes used.
topnode        = nextunusednode - 1;
Tree.method    = Method;
Tree.node      = nodenumber(1:topnode);
Tree.parent    = parent(1:topnode);
Tree.class     = yfitnode(1:topnode);
Tree.var       = cutvar(1:topnode);
Tree.cut       = cutpoint(1:topnode);
Tree.children  = children(1:topnode,:);
Tree.nodeprob  = nodeprob(1:topnode);
Tree.nodeerr   = resuberr(1:topnode);
Tree.risk      = risk(1:topnode);
Tree.nodesize  = nodesize(1:topnode);
Tree.npred     = nvars;
Tree.catcols   = categ;
Tree.names     = names;
if doclass
   if ~haveprior, Prior=[]; end
   Tree.prior     = Prior;
   Tree.nclasses  = nclasses;
   Tree.cost      = Cost;
   Tree.classprob = classprob(1:topnode,:);
   Tree.classcount= classcount(1:topnode,:);
   Tree.classname = cnames;
end

Tree.catsplit  = catsplit; % list of all categorical predictor splits

Tree = removebadsplits(Tree);

if isequal(Prune,'on')
   Tree = treeprune(Tree);
end
end

%----------------------------------------------------
function v=gdi(p)
%GDI Gini diversity index.
%   V = GDI(P) returns, for each row of the probability matrix P, one
%   minus the sum of squared class proportions in that row.

psq = p .* p;
v = 1 - sum(psq, 2);
end

%----------------------------------------------------
function v=twoing(Pleft, P1, Pright, P2)
%TWOING Twoing index.
%   V = TWOING(PLEFT,P1,PRIGHT,P2) evaluates the twoing splitting rule for
%   each candidate split: (PLEFT*PRIGHT/4) times the squared L1 distance
%   between the left and right class-probability rows P1 and P2.

absdiff = sum(abs(P1 - P2), 2);
v = (Pleft .* Pright .* absdiff.^2) / 4;
end

%----------------------------------------------------
function v=deviance(p)
%DEVIANCE Deviance impurity.
%   V = DEVIANCE(P) returns -2 * sum(p .* log(p)) along each row of P,
%   with probabilities floored at eps to avoid log(0).

floored = max(p, eps(class(p)));
v = -2 * sum(p .* log(floored), 2);
end

%----------------------------------------------------
function [critval,cutval]=Ccritval(x,Ccum,rows,iscat,pratio,Pt,impurity,critfun,bestcrit)
%CCRITVAL Get critical value for splitting node in classification tree.
%   [CRITVAL,CUTVAL] = CCRITVAL(X,CCUM,ROWS,ISCAT,PRATIO,PT,IMPURITY,
%   CRITFUN,BESTCRIT) evaluates CRITFUN at every candidate split of the
%   sorted predictor X (cumulative class counts CCUM, distinct-value
%   boundaries ROWS) and returns the best criterion value and the
%   corresponding cut.  For a continuous predictor CUTVAL is the midpoint
%   between the two x values straddling the best split; for a categorical
%   predictor (ISCAT true) it is a 1-by-2 cell of left/right category
%   sets.  BESTCRIT is the best criterion seen so far at this node and is
%   used to bail out of work that cannot improve on it.
   
% First get all possible split points

% Get arrays showing left/right class membership at each split
nsplits = length(rows);
if iscat
   % B contains the class counts in each category
   t = [rows; size(Ccum,1)];
   B = Ccum(t,:);
   B(2:end,:) = B(2:end,:) - B(1:end-1,:);

   Bsums = sum(B,1);
   Bcats = sum(Bsums>0);     % number of classes actually present at this node
   if Bcats>2
      % We have three or more response categories
      % A picks out all category subsets including the 1st,
      % but not the whole set
      A = ones(2^nsplits,nsplits+1);
      A(:,2:end) = fullfact(2*ones(1,nsplits)) - 1;
      A(end,:) = [];
   else
      % We have just two categories, so pick subsets by order of mean proportion
      % NOTE(review): Bcats is a scalar here, so find(Bcats>1,1) yields 1
      % (or empty) and Bcol is effectively always 1; possibly
      % find(Bsums>1,1) or find(Bsums>0,1) was intended -- confirm against
      % the shipped treefit implementation before changing.
      Bcol = find(Bcats>1,1);
      if isempty(Bcol)
         Bcol = 1;
      end

      % A contains the category subsets, arranged in order of response mean
      catmeans = B(:,Bcol) ./ max(1, diff([0; t]));
      [smeans,sorder] = sort(catmeans);
      ncat = length(catmeans);
      A = zeros(ncat-1,ncat);
      T = tril(ones(ncat));
      A(:,sorder) = T(1:end-1,:);
   end

   Csplit1 = A*B;
   nsplits = size(Csplit1,1);
   allx = x(t);               % one representative x value per category
else
   % Split between each pair of distinct ordered values
   Csplit1 = Ccum(rows,:);
end
Csplit2 = Ccum(size(Ccum,1)*ones(nsplits,1),:) - Csplit1; % repmat(Ccum(end,:),nsplits,1) - Csplit1;

% Get left/right class probabilities at each split
temp = pratio(ones(nsplits,1),:); %repmat(pratio,nsplits,1);
P1 = temp .* Csplit1;
P2 = temp .* Csplit2;
Ptleft  = sum(P1,2);
Ptright = sum(P2,2);
nclasses = size(P1,2);
wuns = ones(1,nclasses);
P1 = P1 ./ Ptleft(:,wuns);   %repmat(Ptleft,1,nclasses);
P2 = P2 ./ Ptright(:,wuns);  %repmat(Ptright,1,nclasses);

% Get left/right node probabilities
Pleft = Ptleft ./ Pt;
Pright = 1 - Pleft;

% Evaluate criterion as impurity or otherwise
if impurity
   crit = - Pleft.*feval(critfun,P1);
   t = (crit>bestcrit);   % compute 2nd term only if it would make a difference
   if any(t)
      crit(t) = crit(t) - Pright(t).*feval(critfun,P2(t,:));
   end
else
   crit = feval(critfun, Pleft, P1, Pright, P2);
end

% Return best split point, but bail out early if no improvement
critval = max(crit);
if critval<bestcrit
   cutval = 0;
   return;
end

% Break ties among equally good splits at random (rand in [0,1) keeps the
% index within 1..length(maxloc)).
maxloc = find(crit==critval);
if length(maxloc)>1
   maxloc = maxloc(1+floor(length(maxloc)*rand));
end
if iscat
   t = logical(A(maxloc,:));
   xleft = allx(t);
   xright = allx(~t);
   cutval = {xleft' xright'};
else
   cutloc = rows(maxloc);
   cutval = (x(cutloc) + x(cutloc+1))/2;
end
end

%----------------------------------------------------
function [critval,cutval]=Rcritval(x,Ycum,rows,iscat)
%RCRITVAL Get critical value for splitting node in regression tree.
%   [CRITVAL,CUTVAL] = RCRITVAL(X,YCUM,ROWS,ISCAT) finds the split of the
%   sorted predictor X that maximizes the between-groups sum of squares of
%   the (centered) response, given the cumulative centered response sums
%   YCUM and the distinct-value boundary indices ROWS.  For a continuous
%   predictor CUTVAL is the midpoint between the two straddling x values;
%   for a categorical predictor (ISCAT true) it is a 1-by-2 cell of
%   left/right category sets.
   
% First get all possible split points

% Get arrays showing left/right class membership at each split
if iscat
   % B contains the category sums
   t = [rows; size(Ycum,1)];
   B = Ycum(t,:);
   B(2:end,:) = B(2:end,:) - B(1:end-1,:);

   % A contains the category subsets, arranged in order of response mean
   catmeans = B ./ max(1, diff([0; t]));
   [smeans,sorder] = sort(catmeans);
   ncat = length(catmeans);
   A = zeros(ncat-1,ncat);
   T = tril(ones(ncat));
   A(:,sorder) = T(1:end-1,:);

   Ysplit1 = A*B;
   n1 = A*[t(1);diff(t)];     % observation counts for each left subset
   allx = x(t);               % take one x value from each unique set
else
   % Split between each pair of distinct ordered values
   Ysplit1 = Ycum(rows,:);
   n1 = rows;
end

% Get left/right means of the centered response
N = numel(x);
mu1 = Ysplit1 ./ n1;
mu2 = (Ycum(end) - Ysplit1) ./ (N - n1);

% Between-groups sum of squares; maximizing this minimizes within-node SSE
ssx = n1.*mu1.^2 + (N-n1).*mu2.^2;
critval = max(ssx);
% Break ties among equally good splits at random
maxloc = find(ssx==critval);
if length(maxloc)>1
   maxloc = maxloc(1+floor(length(maxloc)*rand));
end
if iscat
   t = logical(A(maxloc,:));
   xleft = allx(t);
   xright = allx(~t);
   cutval = {xleft' xright'};
else
   cutloc = rows(maxloc);
   cutval = (x(cutloc) + x(cutloc+1))/2;
end
end

% --------------------------------------
function Tree = removebadsplits(Tree)
%REMOVEBADSPLITS Remove splits that contribute nothing to the tree.
%   TREE = REMOVEBADSPLITS(TREE) repeatedly collapses "twigs" (internal
%   nodes whose two children are both leaves) whenever the twig's own risk
%   is no larger than the combined risk of its children, then prunes the
%   marked nodes via TREEPRUNE.  The adjustment factor guards against
%   floating-point ties being treated as improvements.

N = length(Tree.node);
isleaf = (Tree.var==0)';   % no split variable implies leaf node
isntpruned = true(1,N);
doprune = false(1,N);
risk = Tree.risk';
adjfactor = (1 - 100*eps(class(risk)));

% Work up from the bottom of the tree
while(true)
   % Find "twigs" with two leaf children
   branches = find(~isleaf & isntpruned);
   twig = branches(sum(isleaf(Tree.children(branches,:)),2) == 2);
   if isempty(twig)
      break;            % must have just the root node left
   end
   
   % Find twigs to "unsplit" if the error of the twig is no larger
   % than the sum of the errors of the children
   Rtwig = risk(twig);
   kids = Tree.children(twig,:);
   Rsplit = sum(risk(kids),2);
   unsplit = Rsplit >= Rtwig'*adjfactor;
   if any(unsplit)
      % Mark children as pruned, and mark twig as now a leaf
      isntpruned(kids(unsplit,:)) = 0;
      twig = twig(unsplit);   % only these to be marked on next 2 lines
      isleaf(twig) = 1;
      doprune(twig) = 1;
   else
      break;
   end
end

% Remove splits that are useless
if any(doprune)
   Tree = treeprune(Tree,'nodes',find(doprune));
end
end

% ------------------------------------
function idx = getclassindex(cnames,g)
%GETCLASSINDEX Find indices for class names in another list of names
%   IDX = GETCLASSINDEX(CNAMES,G) takes a list CNAMES of class names
%   (such as the grouping variable values in the treefit or classify
%   function) and another list G of group names (as might be supplied
%   in the "prior" argument to those functions), and finds the indices
%   of the CNAMES names in the G list.  CNAMES should be a cell array
%   of strings.  G can be numbers, a string array, or a cell array of
%   strings.  IDX(i) is the position in G of the first exact match for
%   CNAMES{i}, or 0 if CNAMES{i} does not appear in G.

% Convert to common string form, whether input is char, cell, or numeric
if isnumeric(g)
   g = cellstr(strjust(num2str(g(:)), 'left'));
elseif ~iscell(g)
   g = cellstr(g);
end

nclasses = length(cnames);
idx = zeros(1,nclasses);

% Look up each class in the grouping variable.  STRCMP+FIND replaces the
% deprecated STRMATCH(...,'exact'); behavior is identical (first exact
% match, case-sensitive).
for i = 1:nclasses
   j = find(strcmp(cnames{i}, g), 1, 'first');
   if ~isempty(j)
      idx(i) = j;
   end
end
end