% /* infors_visnet.m
% Derived from infors_visnet.c 9 August 2017 E T Rolls
% copyright E T Rolls, Oxford Centre for Computational Neuroscience, https://www.oxcns.org
% No warranty express or implied is provided: this is research software.
% If this software is used or adapted for new published research, 
% it is requested that the original publication, which provides details, is cited.
% This information theoretic analysis method for single neuron activity is described by:
% Rolls,E.T., Treves,A., Tovee,M. and Panzeri,S. (1997) 
% Information in the neuronal representation of individual stimuli in the primate temporal visual cortex. 
% Journal of Computational Neuroscience 4: 309-333.
%
% This software is made available with Rolls, E. T. (2021) Brain Computations: What and How. Oxford University Press, Oxford. https://www.oxcns.org
% Some of the coding in this program reflects the original C program

%    This program is used to extract the amount of information, in bits, that rate responses of
%    individual cells carry about a set of stimuli, and sparseness measures  
%    developed by Alessandro Treves, ale@limbo.sissa.it, with
%    Edmund Rolls, Stefano Panzeri, Bill Skaggs, during 1994-95  */

% timewin (set below, in ms) can be changed from 1000 ms when the input supplies the number of spikes counted in a window of timewin ms
% rather than a rate. The mean rates written to the mr file are always converted to spikes/s.
% This program requires phims.m

format compact; 

% Capacity limits. They mirror the fixed-size arrays of the original C program.
max_s = 50;       % max no of stimuli
max_c = 65536;    % max no of cells: was 1024, then 16384 for 128x128 VisNet, now 65536 for 256x256 VisNet
MAX_B = 20;       % max no of bins for the firing rate distribution, normally up to 25, but for VisNet can be as small as 3 or 2
% max_b is the number of bins the rates are actually divided into (was 25).
% Set this carefully and look at the results; it is never allowed to exceed MAX_B.
max_b = min(3, MAX_B);

% Read in the data for a file. Select the correct place for infos_matlab.mat

load('../VNanalysis/infos_matlab.mat'); % this loads num_c the number of cells, num_s the number of stimuli, and infom. OR
%load('infos_matlab.mat');
% infom has 1 row for each trial
% Each row has the stimulus number in the range 1 : num_s, the number of stimuli, followed by num_c firing rates
% The trials can be in any order. There can be different numbers of trials for each stimulus

%num_c = SQNET_SIZE; in VisNet
%num_s = N_GROUPS;
fprintf('number of cells=%4d   number of stimuli=%4d\n', num_c, num_s);
if num_c > max_c
    fprintf('In infors_visnet max_c is %d and is too low. Edit infors_visnet.m.  Exiting\n', max_c);
    return
end
if num_s > max_s
    fprintf('In infors_visnet max_s is %d and is too low. Edit infors_visnet.m.  Exiting\n', max_s);
    return
end
% Every row must carry a stimulus number followed by num_c rates, otherwise the
% rate columns read later do not exist.
if size(infom, 2) < num_c + 1
    fprintf('infom has %d columns but 1 + num_c = %d are needed.  Exiting\n', size(infom, 2), num_c + 1);
    return
end
% Stimulus numbers are used directly as indices. Values <= 0 or non-integers
% would crash, and values > num_s would silently add trials that are counted in
% the total but never analysed, so the P(s) would no longer sum to 1.
stim_ids = infom(:, 1);
bad_stim = (stim_ids < 1) | (stim_ids > num_s) | (stim_ids ~= round(stim_ids));
if any(bad_stim)
    first_bad = find(bad_stim, 1);
    fprintf('infom row %d has stimulus number %g, which is not an integer in 1 : %d.  Exiting\n', first_bad, stim_ids(first_bad), num_s);
    return
end
% ---- Output file names (all relative to the current directory) ----
% NOTE(review): an earlier version set outfile_name = '../VNanalysis/infos.out'
% (for VisNet) here and then overwrote it with 'infos.out' a few lines later, so
% the VisNet path was never actually used. The effective name is kept; swap the
% two lines below if the VisNet directory is wanted.
%outfile_name = '../VNanalysis/infos.out'; % This is for VisNet.
outfile_name = 'infos.out';
inffile_name = 'res_infors';
mrfile_name = 'mr';
pltfile_name = 'plt';

% ---- Work arrays ----
% Only cells 1:num_c and stimuli 1:num_s are ever indexed, so the arrays are
% sized from the loaded data rather than from max_c x max_s. With the VisNet
% limits the old nrb alone was max_c*max_s*MAX_B doubles (about 0.5 GB).
rcs = zeros(num_c, num_s);       % prob distr mean rates
acs = zeros(num_c, num_s);       % prob distr nonzero rates
ra = zeros(num_c, num_s);        % rate averages
ra_true = zeros(num_c, num_s);   % rate averages before rescaling
sd = zeros(num_c, num_s);        % rate variances
sd_true = zeros(num_c, num_s);   % rate variances before rescaling
f0 = zeros(num_c, num_s);        % fraction of zero rates

Ps = zeros(num_s, 1);            % Prob(s)
s0err = 0;
timewin = 1000;                  % if this is not 1000, the spike numbers will be converted to firing rates using this timewin in ms
tmp = 0.0;
r_a1 = zeros(num_c, 1);
r_a2 = zeros(num_c, 1);          % averages for sparseness
r_a1_true = zeros(num_c, 1);     % before rescaling
sp = zeros(num_c, 1);            % sparseness
lsp = zeros(num_c, 1);           % log sparseness - see Rolls and Treves 1998 eqn A3.19 [a ln(1/a)]
It = zeros(num_c, 1);            % info derivative
It_av = 0;                       % info derivative average
Qsq = zeros(num_s, MAX_B);       % quantized probability table ETR 2010
Qq = zeros(MAX_B, 1);            % quantized Prob(s_q) ETR 2010
Psq = zeros(num_s, MAX_B);       % probability table ETR 2010
Pq = zeros(MAX_B, 1);            % Prob(s_q)  ETR 2010
nrb = zeros(num_c, num_s, MAX_B);% number of rates per bin
nrb0 = zeros(num_c, MAX_B);      % total number of rates per bin
I_raws = zeros(num_s, 1);
I_C1s = zeros(num_s, 1);         % first correction term from data
I_means = zeros(num_s, 1);
eps = 0.000001;                  % small number (deliberately shadows the built-in eps)
%eps = 0.00000000001;	% ETR matlab /* small number */

max_t = 9; % 2 * max_s; This is the number of trials for each stimulus. In VisNet = N_VIEWS_TEST
min_t = 3; % 2 * max_s;
seed = 243342;

% ---- Open the output files; fopen returns -1 on failure ----
outfile = fopen(outfile_name, 'w');
inffile = fopen(inffile_name, 'w');
pltfile = fopen(pltfile_name, 'w');
mrfile = fopen(mrfile_name, 'w');
if any([outfile, inffile, pltfile, mrfile] < 0)
    fprintf('Could not open the output files %s, %s, %s, %s for writing.  Exiting\n', outfile_name, inffile_name, pltfile_name, mrfile_name);
    for fid = [outfile, inffile, pltfile, mrfile]
        if fid >= 0
            fclose(fid);   % do not leak the ones that did open
        end
    end
    return
end

%    *************************************
%    read data and set constant parameters
%    ************************************* 

% Largest (rescaled) rate kept by the binning step; rates are later scaled so
% that they fit into bins 0 .. max_b-1.
rmax_b = (max_b - 1.0) - eps;
rng(seed); % Comment out for Octave

% r(c, s, t) must hold every trial of every stimulus, because the later loops
% run over t = 1 : nr(s). The trial counts per stimulus are therefore taken
% from the data, and max_t is raised if some stimulus has more trials than
% expected. Previously the extra trials were dropped from r but still counted
% in nr and n_tot, which made the later loops index past the end of r.
if isempty(infom)
    trials_needed = 0;
else
    trials_needed = max(accumarray(infom(:, 1), 1));
end
if trials_needed > max_t
    fprintf('Up to %d trials per stimulus found; increasing max_t from %d\n', trials_needed, max_t);
    max_t = trials_needed;
end

nr = zeros(num_s, 1);               % number of trials read for each stimulus
r = zeros(num_c, num_s, max_t);     % r(c, s, t): rate of cell c on trial t of stimulus s
n_tot = 0;                          % total number of trials
for lines = 1 : size(infom, 1)      % each line or row is a trial
    s = infom(lines, 1);            % the stimulus number
    t = nr(s) + 1;
    % infors_visnet assumes integers for correct binning
    r(:, s, t) = round(infom(lines, 2 : num_c + 1)).';
    nr(s) = t;
    n_tot = n_tot + 1;
end
fprintf('total number of trials read from infos_matlab.mat = %d\n', n_tot);


% /******  calculate true firing rates  (before rescaling)    *******/
% For each cell c and stimulus s over the nr(s) trials:
%   ra_true(c, s) = mean rate, in spikes/s
%   sd_true(c, s) = standard error of that mean (sample sd / sqrt(nt)), in spikes/s
%   r_a1_true(c)  = mean rate over all trials (trial-weighted over stimuli), in spikes/s
% The previous version used ra (not yet computed, all zeros) instead of ra_true,
% used the nt of the last stimulus for every stimulus, and accumulated (r/nt)^2
% instead of r^2, so the reported sem was wrong.
rate_scale = timewin / 1000;   % divide by this to convert counts per timewin to spikes/s
for c = 1 : num_c
    r_a1_true(c) = 0.0;
    for s = 1 : num_s
        nt = nr(s);
        sum_r = 0.0;
        sum_r2 = 0.0;
        for t = 1 : nt
            sum_r = sum_r + r(c, s, t);
            sum_r2 = sum_r2 + r(c, s, t) * r(c, s, t);
        end
        if nt > 0
            ra_true(c, s) = sum_r / nt;
        else
            ra_true(c, s) = 0.0;
        end
        if nt > 1
            % sample sd = sqrt((sum r^2 - nt*mean^2) / (nt-1)); abs() guards against
            % tiny negative values from rounding when all rates are equal
            tmp = sqrt(abs((sum_r2 - nt * ra_true(c, s) * ra_true(c, s)) / (nt - 1.0)));
            tmp = tmp / rate_scale;                 % now in spikes/s
            sd_true(c, s) = tmp / sqrt(nt);         % standard error of the mean
        else
            sd_true(c, s) = 0.0;                    % no spread can be estimated from < 2 trials
        end
        r_a1_true(c) = r_a1_true(c) + ra_true(c, s) * nt / n_tot;
        ra_true(c, s) = ra_true(c, s) / rate_scale; % now in spikes/sec.
    end
    r_a1_true(c) = r_a1_true(c) / rate_scale;       % now in spikes/sec
end

% Write the mr file: for every cell two lines, each starting with the cell number
% and its overall mean rate. The first line continues with the mean rate to each
% stimulus, the second with the corresponding standard errors of the mean.
%fprintf(mrfile, '# mean firing rates (in spikes/sec) timewin=%6.1f ms\n', timewin);
fprintf(mrfile, '# cell, mean f.r. of the cell, mean f.r. to each stimulus\n');
fprintf(mrfile, '# cell, mean f.r. of the cell, sem       to each stimulus\n');
for cell_idx = 1 : num_c
    for row_kind = 1 : 2
        if row_kind == 1
            per_stim = ra_true(cell_idx, :);   % mean rates
        else
            per_stim = sd_true(cell_idx, :);   % sems
        end
        fprintf(mrfile, '%2d %6.3f ', cell_idx, r_a1_true(cell_idx));
        for stim_idx = 1 : num_s
            fprintf(mrfile, '%6.3f ', per_stim(stim_idx));
        end
        fprintf(mrfile, '\n');
    end
end
fclose(mrfile);
fprintf('Closed mrfile\n');
% End of the calculation of true firing frequencies.
% Rescale, bin and summarise each cell's responses. For every cell c:
%  * r_max is its largest rate over all trials. If r_max exceeds rmax_b, all of
%    the cell's rates in r are scaled in place by rmax_b / r_max.
%  * Each rate goes to bin b, the smallest b >= 0 with r <= b + eps.
%    nrb(c, s, b+1) and nrb0(c, b+1) count the trials falling in each bin.
%  * ra, sd and f0 hold, per stimulus, the mean, the sample variance and the
%    fraction of (near-)zero values of the rescaled rates.
%  * rcs / acs parametrise a fitted distribution: with weight acs a Poisson of
%    mean rcs, and with weight 1 - acs a spike at zero (acs < 1 only when the
%    variance exceeds the mean). Observed, fitted-Poisson and Gaussian (phims)
%    bin counts are written to outfile, each as max_b numbers.
%  * r_a1 / r_a2 accumulate the trial-weighted averages of ra and ra^2, which
%    are used for the sparseness.
%  * It(c) = (sum_s P(s) ra_true ln ra_true - r_a1_true ln r_a1_true) / ln 2
%    uses the unscaled rates. It_av is its mean over cells.
It_av = 0.0;
for c = 1 : num_c
    r_max = 0.0;
    It(c) = 0.0;
    r_a1(c) = 0.0;
    r_a2(c) = 0.0;
    for s = 1 : num_s
        nt = nr(s);
        for t = 1 : nt
            if r(c, s, t) > r_max
                r_max = r(c, s, t);
            end
        end
    end
%    fprintf('c=%d  r_max=%f rmax_b=%f\n', c, r_max, rmax_b); % debugging
    for s = 1 : num_s
        nt = nr(s);
        ra(c, s) = 0;
        sd(c, s) = 0;
        f0(c, s) = 0;
        for b = 1 : max_b
            nrb(c, s, b) = 0;
        end
        for t = 1 : nt
            if r_max > rmax_b
                r(c, s, t) = r(c, s, t) * (rmax_b / r_max);
            end
            b = 0;
            uppe = eps;
            while r(c, s, t) > uppe
                b = b + 1;
                uppe = b + eps;
            end
            nrb(c, s, b+1) = nrb(c, s, b+1) + 1;
            nrb0(c, b+1) = nrb0(c, b+1) + 1;

            ra(c, s) = ra(c, s) + (r(c, s, t) / nt);
            % Accumulate sum r^2 / (nt-1). A single trial has no sample variance,
            % and dividing by nt-1 = 0 would give Inf/NaN.
            if nt > 1
                sd(c, s) = sd(c, s) + ((r(c, s, t) * r(c, s, t)) / (nt - 1.0));
            end
            if r(c, s, t) < eps
                f0(c, s) = f0(c, s) + (1. / nt);
            end
        end
        fprintf(outfile, 'c=%d s=%d   ', c, s);
        for b = 1 : max_b
            fprintf(outfile, '    %d', nrb(c, s, b));
        end
        fprintf(outfile, '\n');
        if nt > 1
            sd(c, s) = sd(c, s) - (ra(c, s) * ra(c, s) * nt / (nt - 1.0));	% /* back to sd (AT) */ - sample variance
        end
        acs(c, s) = 1.0;
        if ra(c, s) <= (0.2 / nt)
            rcs(c, s) = 0.2 / nt;
        elseif sd(c, s) <= ra(c, s)
            rcs(c, s) = ra(c, s);
        else
            rcs(c, s) = sd(c, s) / ra(c, s) - 1.0 + ra(c, s);
            acs(c, s) = ra(c, s) / rcs(c, s);
        end
        r_a1(c) = r_a1(c) + (ra(c, s) * nt / n_tot);
        r_a2(c) = r_a2(c) + (ra(c, s) * ra(c, s) * nt / n_tot);
        if ra(c, s) > eps
            It(c) = It(c) + (log(ra_true(c, s)) * ra_true(c, s) * nt / n_tot);
        end
        % Expected counts under the fitted distribution. Bin 0 also gets the
        % 1 - acs weight at zero. Bins 1 .. max_b-1 follow, so this row has max_b
        % entries like the observed row (the loop used to run to max_b, printing
        % one value too many).
        fprintf(outfile, 'Poisson    ');
        fact = acs(c, s) * exp(-rcs(c, s));
        fprintf(outfile, ' %4.0f', (fact + 1. - acs(c, s)) * nt);
        for b = 1 : max_b - 1
            fact = fact * (rcs(c, s) / b);
            fprintf(outfile, ' %4.0f', fact * nt);
        end
        fprintf(outfile, '\n');
        % Expected counts under a Gaussian with mean ra and spread sd, from phims
        % (defined in phims.m). The zero bin uses the observed f0 when it is nonzero.
        fprintf(outfile, 'Gaussian   ');
        if f0(c, s) > eps
            fact = f0(c, s);
        else
            fact = phims(0.0, ra(c,s), sd(c, s));
        end
        fprintf(outfile, ' %4.0f', fact * nt);
        uppe = 0.0;
        for b = 2 : max_b
            fact = -phims(uppe, ra(c, s), sd(c, s));
            uppe = uppe + 1.;
            fact = fact + (phims(uppe, ra(c, s), sd(c, s)));
            fprintf(outfile, ' %4.0f', fact * nt);
        end
        fprintf(outfile, '\n');
    end

    if r_a1(c) > eps
        It(c) = It(c) - (log(r_a1_true(c)) * r_a1_true(c));
    end
    It(c) = It(c) / log(2.0);
    It_av = It_av + (It(c) / num_c);
end

% 
% Prior probability of each stimulus, P(s) = (trials of s) / (all trials),
% reported both to outfile and to the console.
stim_line_fmt = ' s=%2d  nr=%3d  P(s)=%5.3f\n';
for stim = 1 : num_s
    Ps(stim) = nr(stim) / n_tot;
    fprintf(outfile, stim_line_fmt, stim, nr(stim), Ps(stim));
    fprintf(stim_line_fmt, stim, nr(stim), Ps(stim));
end
fprintf('Assigned probs for each stimulus\n');

% Population accumulators filled by the per-cell information loop
I_mean_raw = 0.;
I_mean_C1 = 0.;
dp = 1. / n_tot;           % probability mass of one trial
I_means(1 : num_s) = 0.0;
% 
%    /* fprintf(pltfile, "# information from single cells\n"); */
%    //fprintf(pltfile, "# cell   Stim  Stim_Spec_Info  Rate   Rate_Std \n");
% 
% /* ***********************************
%    loop calculating info for each cell
%    ***********************************    */
% For each cell c (bins b = 1 .. max_b; bin b holds rescaled rate b-1):
%   Qsq(s, b) = nrb(c, s, b) / n_tot : observed joint probability of stimulus s and bin b
%   Qq(b)     = sum_s Qsq(s, b)      : observed probability of bin b
%   Psq, Pq   : the same tables from the fitted distribution (rcs, acs). They are
%               built here but not used by the information estimates below.
%   I_raw     = sum_{s,b} Qsq ln(Qsq / (Qq Ps)) / ln 2            (bits)
%   I_raws(s) = sum_b (Qsq/Ps) ln(Qsq / (Qq Ps)) / ln 2           (stimulus-specific)
%   I_C1      = (sum_s nb_s - #occupied bins overall - (num_s - 1)) * dp / (2 ln 2),
%               a correction subtracted from I_raw. nb_s is the number of bins
%               occupied for stimulus s, or an estimate of the relevant bins when
%               some are empty. NOTE(review): this appears to be the Panzeri &
%               Treves limited-sampling bias correction - confirm against the paper.
% Writes per-cell lines to outfile, inffile (cell, I_raw, I_C1, I) and pltfile
% (cell, stimulus, I_raws, mean rate, rate minus mean over stimuli).
for c = 1: num_c
    I_raw = 0.;
    I_C1 = 0.;
    for s = 1 : num_s
        I_raws(s) = 0.;
        I_C1s(s) = 0.;
        % Bin 1 (zero rate): Poisson P(0) plus the (1 - acs) weight at zero
        fact = acs(c, s) * exp(-rcs(c, s));
        Psq(s, 1) = Ps(s) * (fact + 1. - acs(c, s));
        Qsq(s, 1) = nrb(c, s, 1) / n_tot;
        % Bin b >= 2 (rate b-1): Poisson term rcs^(b-1) exp(-rcs) / (b-1)!.
        % The loop used to start at b = 1 with divisor b, which overwrote bin 1
        % with P(1) and shifted every Poisson term one bin to the left.
        for b = 2 : max_b
            fact = fact * (rcs(c, s) / (b - 1));
            Psq(s, b) = Ps(s) * fact;
            Qsq(s, b) = nrb(c, s, b) / n_tot;
        end
    end
    for b = 1 : max_b
        Pq(b) = 0.;
        Qq(b) = 0.;
        for s = 1 : num_s
            Pq(b) = Pq(b) + Psq(s, b);
            Qq(b) = Qq(b) + Qsq(s, b);
        end
    end
    % Sparseness a = (sum_s P(s) ra)^2 / sum_s P(s) ra^2 of the rescaled mean
    % rates, and lsp = -a ln(a)
    sp(c) = 0.0;
    lsp(c) = 0.0;
    if r_a2(c) > 0.0
        sp(c) = r_a1(c) * r_a1(c) / r_a2(c);
        lsp(c) = -sp(c) * log(sp(c));
    end

    for s = 1 : num_s
        nt = nr(s);
        nb = 0;   % number of bins occupied for this stimulus
        for b = 1 : max_b
            q1 = Qsq(s, b);
            if (q1 > eps)
                I_raw = I_raw + (q1 * log(q1 / (Qq(b) * Ps(s))));
                I_raws(s) = I_raws(s) + ((q1 / Ps(s)) * log(q1 / (Qq(b) * Ps(s))));
                nb = nb + 1;
            end
        end
        % Some bins are empty for this stimulus. Add xtr unobserved bins one at a
        % time while doing so brings nb_x (the expected number of occupied bins
        % after nt trials) closer to the observed nb, then update nb accordingly.
        if nb < max_b
            nb_x = 0.0;
            for b = 1 : max_b
                qc_x = ((Qsq(s, b) / Ps(s) - eps) * nt + 1.) / (nt + nb);
                if (Qsq(s, b) > eps)
                    nb_x = nb_x + (1. - exp(log(1. - qc_x) * nt));
                end
            end
            delta_N_prev = max_b * max_b;
            delta_N = (nb - nb_x) * (nb - nb_x);
            xtr = 0;
            while delta_N < delta_N_prev && (nb + xtr) < max_b
                xtr = xtr+ 1;
                nb_x = 0.0;
                gg = log(1. + 0.8 * nb / nt) * xtr / nt;
                xxx = 0;
                for b = 1 : max_b
                    if Qsq(s, b) > eps
                        qc_x = (1. - gg) * ((Qsq(s, b) / Ps(s)) * nt + 1.) / (nt + nb);
                        nb_x = nb_x + (1. - exp(log(1. - qc_x) * nt));
                    elseif (xxx < xtr)
                        qc_x = gg / xtr;
                        nb_x = nb_x + (1. - exp(log(1. - qc_x) * nt));
                        xxx = xxx + 1;
                    end
                end
                delta_N_prev = delta_N;
                delta_N = (nb - nb_x) * (nb - nb_x);
            end
            nb = nb + xtr - 1;
            if delta_N < delta_N_prev
                nb = nb + 1;
            end
        end

        if Ps(s) > eps
            I_C1s(s) = (1. / Ps(s)) * (nb - 1.) - 1.;
        end
        I_C1 = I_C1 + nb;
        for b = 1 : max_b
            if ((Ps(s) > eps) && (Qsq(s, b) > eps))
                q1 = Qsq(s, b) / Ps(s);
                I_C1s(s) = I_C1s(s) + ((-q1 + 2 * q1 * q1) / (Qq(b)));
            end
        end
        I_C1s(s) = I_C1s(s) / (2. * n_tot *log(2.));

        I_raws(s) = I_raws(s) / log(2.);
    end
    % One fewer degree of freedom for every bin occupied across all stimuli
    for b = 1: max_b
        if (Qq(b) > eps)
            I_C1 = I_C1 - 1.;
        end
    end

    I_raw = I_raw / log(2.);
    I_mean_raw = I_mean_raw + I_raw / num_c;
    I_C1 = I_C1 - (num_s - 1.);
    I_C1 = I_C1 * (dp / (2. * log(2.)));
    I_mean_C1 = I_mean_C1 + (I_C1 / num_c);

    for s = 1 : num_s
        I_means(s) = I_means(s) + ((I_raws(s) - I_C1s(s)) / num_c);
    end
%       
    fprintf(outfile, 'cell=%d spars=%4.2f lsp=%6.2f %5.3f\n', c, sp(c), lsp(c), It(c));
    fprintf(outfile, 'RawInfo=%5.3f C_1=%5.3f I=%5.3f\n', I_raw, I_C1, I_raw - I_C1);
    %fprintf('cell=%2d sparseness(a)=%4.2f lsp=%6.2f %5.3f RawInfo=%5.3f Corr_1=%6.3f Inf=%5.3f bits\n', c, sp(c), lsp(c), It(c), I_raw, I_C1, I_raw - I_C1);
    %fprintf('cell=%2d sparseness(a)=%4.2f RawInfo=%5.3f Corr_1=%6.3f CorrectedInfo=%5.3f bits\n', c, sp(c), I_raw, I_C1, I_raw - I_C1);
    fprintf(inffile, '%d %5.3f %5.3f %5.3f \n', c, I_raw, I_C1, I_raw - I_C1);
%       /* Modified 29 March 2013 to add the mean rate for each stimulus, and a RateDifference */
    SumRates = 0.0;
    for s = 1 :num_s
        SumRates = SumRates + ra_true(c, s);
    end
    SumRates = SumRates / num_s; % /* Nov 2014 */
    for s = 1 : num_s
        RateDifference = (ra_true(c, s)) - SumRates; % /* Large positive values will indicate neurons responding with a high rate to the best stimulus, and low rates to the others Nov 2014*/ 
        fprintf(pltfile, '%5d %2d %4.3f %9.2f %9.2f\n', c, s, I_raws(s), ra_true(c, s), RateDifference);
    end
end

% 
% Population averages over all cells: raw info, correction C_1, corrected info
% and the info-derivative average
fprintf(outfile, 'MeanRawInfo=%5.3f Mean C_1=%5.3f Mean I=%5.3f It=%5.3f\n', I_mean_raw, I_mean_C1, I_mean_raw - I_mean_C1, It_av);
% The C original could also write the population means to pltfile; that output
% stays disabled:
%   fprintf(pltfile, '# Mean values of the cell population \n');
%   fprintf(pltfile, '   %6.3f ', I_mean_raw - I_mean_C1);
%   fprintf(pltfile, '%6.3f ', I_means(1:num_s));
%   fprintf(pltfile, '\n');
% Release the three output files still open (mrfile was closed earlier)
for fid = [outfile, pltfile, inffile]
    fclose(fid);
end
  