function [varargout] = grpstats(x,group,whichstats,alpha)
%GRPSTATS Summary statistics by group.
%   MEANS = GRPSTATS(X,GROUP) returns the MEANS of each column of X by
%   GROUP. X is a matrix of observations.  GROUP is a grouping variable
%   defined as a categorical variable, vector, string array, or cell array
%   of strings.  GROUP can also be a cell array of several grouping
%   variables (such as {G1 G2 G3}) to group the values in X by each unique
%   combination of grouping variable values.  GROUP can be [] or omitted to
%   compute the mean of the entire sample without grouping.  When there is
%   a single grouping variable, groups are sorted by order of appearance
%   (if GROUP is character), sorted numeric value (if GROUP is numeric), or
%   order of the levels property (if GROUP is categorical).
%
%   GRPSTATS(X,GROUP,ALPHA) displays a plot of the means versus index with
%   100(1 - ALPHA)%  confidence intervals around each mean.
%
%   [A,B,...] = GRPSTATS(X,GROUP,WHICHSTATS) returns the statistics specified
%   in WHICHSTATS.  WHICHSTATS can be a single function handle or name, or a
%   cell array containing multiple function handles or names.  The number of
%   outputs (A,B,...) must match the number of function handles and names in
%   WHICHSTATS.  The names can be chosen from among the following:
%
%      'mean'     mean
%      'sem'      standard error of the mean
%      'numel'    count, or number of elements
%      'gname'    group name
%      'std'      standard deviation
%      'var'      variance
%      'meanci'   95% confidence interval for the mean
%      'predci'   95% prediction interval for a new observation
%
%   Each function included in WHICHSTATS must accept a column vector of data
%   and compute a descriptive statistic for it.  For example, @median and
%   @skewness are suitable functions.  The function typically returns a scalar
%   value, but may return an NVALS-by-1 column vector if the descriptive
%   statistic is not a scalar (a confidence interval, for example).  The size
%   of each output A, B, ... is NGROUPS-by-NCOLS-by-NVALS, where NGROUPS is
%   the number of groups, NCOLS is the number of columns in the data X, and
%   NVALS is the number of values returned by the function for data from a
%   single group in one column of X.  If X is a vector of data, then the size
%   of each output A, B, ... is NGROUPS-by-NVALS.
%
%   A function included in WHICHSTATS may also be written to accept a matrix
%   of data and compute a descriptive statistic for each column.  The function
%   should return either a row vector, or an NVALS-by-NCOLS matrix if the
%   descriptive statistic is not a scalar.
%
%   [...] = GRPSTATS(X,GROUP,WHICHSTATS,ALPHA) specifies the confidence
%   level as 100(1-ALPHA)% for the 'meanci' and 'predci' options.  It does
%   not display a plot.  ALPHA must be a numeric scalar strictly between 0
%   and 1; the default is 0.05.
%
%   Example:  
%      load carsmall
%      [m,p,g] = grpstats(Weight,Model_Year,{'mean','predci','gname'})
%      n = length(m)
%      errorbar((1:n)',m,p(:,2)-m)
%      set(gca,'xtick',1:n,'xticklabel',g)
%      title('95% prediction intervals for mean weight by year')
%
%   See also GSCATTER, GRP2IDX.

%   Older syntax still supported:
%   [MEANS,SEM,COUNTS,GNAME] = GRPSTATS(X,GROUP) returns the standard error
%   of the mean in SEM, the number of elements in each group in COUNTS,
%   and the name of each group in GNAME.

%   Copyright 1993-2007 The MathWorks, Inc. 
%   $Revision: 2.15.2.13 $  $Date: 2007/06/14 05:25:24 $

if (nargin<1)
   error('stats:grpstats:TooFewInputs',...
         'GRPSTATS requires at least one argument.')
end
if ndims(x)>2 || ~(isnumeric(x) || islogical(x))
    error('stats:grpstats:BadData',...
          'X must be a numeric or logical vector or matrix.');
elseif isvector(x)
    % A vector in either orientation is treated as a single data column
    x = x(:);
end
[rows,cols] = size(x);

% Recognize plotting syntax with alpha in 3rd position
doplot = false;
if nargin==3
    if isnumeric(whichstats) && isscalar(whichstats)
        alpha = whichstats;
        whichstats = {};
        doplot = true;
    else
        alpha = 0.05;
    end
elseif nargin<=2
    alpha = 0.05;   % used in nested functions
    whichstats = {};
end

% Written as a positive range test so that NaN (for which every comparison
% is false) and non-scalar or non-numeric values are rejected with the
% BadAlpha error instead of slipping through or failing inside "||".
if ~(isnumeric(alpha) && isscalar(alpha) && alpha>0 && alpha<1)
    error('stats:grpstats:BadAlpha',...
          'ALPHA must be a number larger than 0 and smaller than 1.');
end

% Get list of statistics functions to call
if isempty(whichstats)
    % Default list
    whichstats = {@(x) mean(x,1), @(x) std(x,0,1) / sqrt(size(x,1)), @(x) size(x,1), 'gname'};
    % The plot below needs the first three defaults (mean, sem, count) even
    % when fewer outputs were requested; otherwise compute only what the
    % caller asked for (at least one).
    if doplot
        minargs = 3;
    else
        minargs = 1;
    end
    whichstats = whichstats(1:max(minargs,nargout));
else
    if ~iscell(whichstats)
        whichstats = {whichstats};
    end

    % Convert keywords to function handles
    for j=1:numel(whichstats)
        hfun = whichstats{j};
        if ischar(hfun)
            switch(hfun)
              case 'mean',  hfun = @(x) mean(x,1);
              case 'sem',   hfun = @(x) std(x,0,1) / sqrt(size(x,1));
              case 'std',   hfun = @(x) std(x,0,1);
              case 'var',   hfun = @(x) var(x,0,1);
              case 'numel', hfun = @(x) size(x,1);
              case 'meanci',hfun = @meanci;
              case 'predci',hfun = @predci;
              %otherwise, may be a function name or 'gname'
            end
        whichstats{j} = hfun;
        end
    end
    
    % Warn if they won't get some of what's listed in whichstats; they will get an
    % error if not enough is listed in whichstats
    if max(1,nargout) < numel(whichstats)
        warning('stats:grpstats:ArgumentMismatch',...
                ['GRPSTATS called with %d output arguments to compute %d\n', ...
                 'summary statistics.'],nargout,numel(whichstats));
    end
end

% Get grouping variable information
if (nargin<2) || isempty(group)
   % No grouping requested: put every row into a single group
   group = ones(rows,1);
end
% NOTE(review): mgrp2idx is defined outside this file.  From the way its
% outputs are used below, GROUP appears to come back as a vector of group
% indices 1..NGROUPS, GLABEL as a cell array of label strings, GROUPNAME as
% the group names returned for 'gname', and MULTIGROUP as a flag for
% multiple grouping variables -- confirm against mgrp2idx.
[group,glabel,groupname,multigroup,ngroups] = mgrp2idx(group,rows);
if length(group) ~= rows
    error('stats:grpstats:InputSizeMismatch',...
          'Must have one GROUP for each row of X.');
end

% Collect group information
groups = cell(1,ngroups);
for gnum = 1:ngroups
    groups{gnum} = find(group==gnum);   % row indices belonging to this group
end

nfuns = numel(whichstats);
varargout = cell(1,max(1,nfuns));

for nfun = 1:nfuns
    hfun = whichstats{nfun};   % get function handle or name
    if isequal(hfun,'gname')
        % special case for gname, not applied separately to each column
        varargout{nfun} = groupname;
        continue
    end

    % Should we try to apply the function to an entire matrix or just a column?
    % Matrix mode is skipped when X has NaNs, because the column-wise path
    % below removes NaNs separately for each column.
    trymatrix = (cols~=1) && ~any(isnan(x(:)));

    % Test the function to see what we get
    if isempty(groups)
        rowidx = [];
    else
        rowidx = groups{1};
    end

    if trymatrix
        % Attempt to call the function on a data matrix
        try
            t = feval(hfun,x(rowidx,:));
            if size(t,2)~=cols
                % Result is not one column per data column; fall back
                trymatrix = false;
            end
        catch
            trymatrix = false;
        end
    end
    if trymatrix
        % Success, put results for this group into an array
        nstatvals = size(t,1);
        t1 = reshape(t',[1,cols,nstatvals]);   % 1st dim for groups
        z = repmat(t1,[ngroups,1,1]);          % one per group
        tsize = [nstatvals,cols];              % size required of later results
    else    
        % Call the function on one column
        if size(x,2)>=1
            y = x(rowidx,1);
        else
            y = x(rowidx,[]);
        end
        if ngroups > 0
            t = tryeval(hfun,y(~isnan(y),:),glabel{1});
        else
            % No groups, so there is no label to report in error messages
            t = tryeval(hfun,y(~isnan(y),:));
        end
        nstatvals = size(t,1);
        t1 = reshape(t,[1,1,nstatvals]);       % dims 1-2 for group,col
        z = repmat(t1,[ngroups,cols,1]);       % one per group and col
        tsize = size(t);                       % size required of later results
        if ngroups>0 && cols>0
            for colnum = 2:cols
                % Now do the rest of the columns
                y = x(rowidx,colnum);
                z(1,colnum,:) = tryeval(hfun,y(~isnan(y),:),glabel{1},tsize);
            end
        end
    end

    % Now do the rest of the groups
    for gnum = 2:ngroups
        idx = groups{gnum};

        if trymatrix
            z(gnum,:,:) = tryeval(hfun,x(idx,:),glabel{gnum},tsize)';
        else
            for colnum = 1:cols
                y = x(idx,colnum);
                z(gnum,colnum,:) = tryeval(hfun,y(~isnan(y),:),glabel{gnum},tsize);
            end
        end
    end
    
    % Special case:  don't add 3rd dimension if there is just one column
    if cols==1
        z = reshape(z,ngroups,nstatvals);
    end
    varargout{nfun} = z;
end

if doplot
   % In plotting mode the default statistics list is in effect, so the
   % first three results are the mean, the sem, and the count.
   means = varargout{1};
   sems = varargout{2};
   counts = varargout{3};
   p = 1 - alpha/2;
   xd = repmat((1:ngroups)',1,cols);
   h = errorbar(xd,means,tinv(p,counts-1) .* sems);
   set(h,'Marker','o','MarkerSize',2);
   set(gca,'Xlim',[0.5 ngroups+0.5],'Xtick',(1:ngroups));
   xlabel('Group');
   ylabel('Mean');
   if (multigroup)
      % Turn off tick labels and axis label
      % (the axes UserData is read back by doresize as the label line count)
      set(gca, 'XTickLabel','','UserData',size(groupname,2));
      xlabel('');
      ylim = get(gca, 'YLim');
      
      % Place multi-line text approximately where tick labels belong
      for j=1:ngroups
         text(j,ylim(1),glabel{j,1},'HorizontalAlignment','center',...
              'VerticalAlignment','top', 'UserData','xtick');
      end

      % Resize function will position text more accurately
      set(gcf, 'ResizeFcn', @resizefcn, 'Interruptible','off');
      doresize(gcf);
   else
      set(gca, 'XTickLabel',glabel);
   end
   title('Means and Confidence Intervals for Each Group');
   set(gca, 'YGrid', 'on');
end

% Nested functions below here; they use alpha from caller
    % MEANCI returns [lower; upper] limits of the 100(1-alpha)% confidence
    % interval for the mean of each column of y.  The extra arguments are
    % never passed; listing them keeps those names local to this function.
    function ci = meanci(y,m,s,n,d) % m,s,n,d are local variables
    n = size(y,1);
    m = mean(y,1);
    s = std(y,0,1) / sqrt(n);
    d = s * -tinv(alpha/2, max(0,n-1));
    ci = [m-d; m+d];
    end

    % ----------------------------
    % PREDCI returns [lower; upper] limits of the 100(1-alpha)% prediction
    % interval for a new observation, for each column of y.  The extra
    % arguments are never passed; they only keep those names local.
    function ci = predci(y,m,s,n,d) % m,s,n,d are local variables
    n = size(y,1);
    m = mean(y,1);
    s = std(y,0,1) * sqrt(1 + 1/n);
    d = s * -tinv(alpha/2, max(0,n-1));
    ci = [m-d; m+d];
    end
end


% ----------------------------
function t = tryeval(f,y,glabel,tsize)
%TRYEVAL Evaluate a statistic function and check the size of its result.
%   T = TRYEVAL(F,Y) calls FEVAL(F,Y) and requires the result to be a
%   column vector (a scalar qualifies).
%   T = TRYEVAL(F,Y,GLABEL) does the same, and mentions the group label
%   GLABEL in any error message.
%   T = TRYEVAL(F,Y,GLABEL,TSIZE) instead requires SIZE(T) to equal TSIZE.
%
%   If F throws, the error is reported as stats:grpstats:FunctionError.  If
%   the result has the wrong size, stats:grpstats:BadFunctionResult is
%   raised.

status = 0;                 % 0 = ok, 1 = F raised an error, 2 = bad result size
havesize = (nargin>=4);     % was an expected size supplied?
try
    t = feval(f,y);
    if havesize
        sizeok = isequal(size(t),tsize);
    else
        sizeok = isvector(t) && isequal(size(t,2),1);
    end
    if ~sizeok
        status = 2;
    end
catch
    status = 1;
end

% Nothing more to do when the call succeeded with an acceptable result
if status==0
    return
end

% Printable name of the function for the message
if ~ischar(f)
    fname = func2str(f);
else
    fname = f;
end

% When there are no data, we don't have a group name
gtext = '';
if nargin >= 3
    nl = sprintf('\n');
    glabel(glabel==nl) = '_';       % keep a multi-line label on one line
    gtext = sprintf(' when\nevaluating data for group ''%s''',glabel);
end

if status==1
    % Report the message of the error that F itself produced
    error('stats:grpstats:FunctionError', ...
          ['The function ''%s'' generated the following error%s:\n\n%s'],fname,gtext,lasterr);
end

gotsize = num2str(size(t));
if havesize
    error('stats:grpstats:BadFunctionResult', ...
          ['Function ''%s'' returned a result of size [%s]%s, ', ...
           'expected size [%s].'],fname,gotsize,gtext,num2str(tsize));
else
    error('stats:grpstats:BadFunctionResult', ...
          ['Function ''%s'' returned a result of size [%s]%s, ', ...
           'expected a scalar or column vector.'],fname,gotsize,gtext);
end

end

% ----------------
function resizefcn(varargin)
%RESIZEFCN Figure resize callback.
%   Any callback arguments are ignored; the layout of the figure whose
%   callback is executing is recomputed by DORESIZE.
fig = gcbf;
doresize(fig);
end

% -------------------------
function doresize(f)
%DORESIZE Adjust figure layout to make sure labels remain visible.
%   DORESIZE(F) looks in figure F for the objects whose UserData is
%   'xtick' (the text labels GRPSTATS draws in place of X tick labels),
%   makes room for them below the current axes of F, and moves them to
%   the proper height.  If F has no such objects, the figure's ResizeFcn
%   is cleared and nothing else is done.
%
%   All axes properties are read from the current axes of F itself, never
%   from GCA, so the layout is correct (and no stray axes are created)
%   when F is not the current figure.
h = findobj(f, 'UserData','xtick');
if (isempty(h))
   % No fake tick labels are left, so the callback is no longer needed
   set(f, 'ResizeFcn', '');
   return;
end
ax = get(f, 'CurrentAxes');
nlines = get(ax, 'UserData');   % GRPSTATS stores size(groupname,2) here

% Position the axes so that the fake X tick labels have room to display
set(ax, 'Units', 'characters');
p = get(ax, 'Position');
ptop = p(2) + p(4);             % top edge of the axes stays fixed
if (p(4) < nlines+1.5)
   % Axes too short for all label lines: give the labels the lower half
   p(2) = ptop/2;
else
   p(2) = nlines + 1;
end
p(4) = ptop - p(2);
set(ax, 'Position', p);
set(ax, 'Units', 'normalized');

% Position the labels at the proper place
xl = get(ax, 'XLabel');
set(xl, 'Units', 'data');
p = get(xl, 'Position');
ylim = get(ax, 'YLim');
p2 = (p(2)+ylim(1))/2;          % halfway between the X label and axis bottom
for j=1:length(h)
   p = get(h(j), 'Position') ;
   p(2) = p2;
   set(h(j), 'Position', p);
end
end
