# -*- coding: utf-8 -*-
"""
Created on Mon Mar 16 15:42:58 2020

@author: rieman01

Code to obtain the Bland-Altman plots for the spectral shape for the HS-, GOIA-,
and WURST-SPECIAL sequence variants.
"""

import matplotlib.pyplot as plt
import numpy as np
import glob
import matplotlib
from xlwt import Workbook

# Global font settings for every figure drawn by this script.
# Only weight and size are set. 'normal' is not a font family matplotlib
# knows (the generic families are 'serif', 'sans-serif', 'cursive',
# 'fantasy' and 'monospace'). Requesting it only produced
# "findfont: Font family ['normal'] not found" warnings before matplotlib
# fell back to its default font.
font = {'weight': 'normal',
        'size': 15}

matplotlib.rc('font', **font)
plt.close('all')  # close any figures still open (e.g. from a previous run)

def bland_altman_plot(data1, data2, a,b,col,coli, *args, **kwargs):
	"""to generate points in BA plots for the differnet pulses and scenarios; 
	blue = HS, orange = GOIA, green = WURST, upper row: R_0, middle row: R_1,Mc,
	bottom row: R_1,Wc """
	data1     = np.asarray(data1)
	data2     = np.asarray(data2)
	meani      = np.mean(np.sum([abs(data1), abs(data2)], axis=1))
	diffabs=np.sum(abs(data1))-np.sum(abs(data2))
	axs[a,b].scatter(meani, diffabs, color=col,marker='o',edgecolor=coli, *args, **kwargs)
	return diffabs
	
def alterman(diff, a, b, axes_grid=None):
	"""Draw the bias and the 1.96-SD limits of agreement into subplot ``[a, b]``.

	Adds a red dashed line at the mean of *diff*. Adds grey dashed lines at
	mean + 1.96*SD and mean - 1.96*SD, where SD is the sample standard
	deviation (ddof=1). Prints the mean (3 decimals) and the SD (4 decimals).

	Parameters
	----------
	diff : array_like
		Differences (the y values of the Bland-Altman points) for one panel.
	a, b : int
		Row and column index of the target subplot.
	axes_grid : 2-D indexable of Axes, optional
		Subplot grid to draw into. Defaults to the module-level ``axs``.

	Returns
	-------
	(md, sd)
		Mean and sample standard deviation of *diff*.
	"""
	if axes_grid is None:
		axes_grid = axs  # subplot grid created by the module-level script
	md = np.mean(diff)              # mean of the differences (bias)
	sd = np.std(diff, ddof=1)       # sample standard deviation of the differences
	print(np.round(md, 3), np.round(sd, 4))
	ax = axes_grid[a, b]
	ax.axhline(md, color='red', linestyle='--')
	for limit in (md + 1.96 * sd, md - 1.96 * sd):
		ax.axhline(limit, color='gray', linestyle='--')
	return md, sd
	
# Load every spectrum in a deterministic order. glob.glob() returns files in
# arbitrary, file-system-dependent order (alphabetical on NTFS, unordered on
# many Linux/macOS file systems). The reshape below assigns each file to a
# (subject, pulse, acquisition) slot purely by position, so the listing must
# be sorted first.
files = sorted(glob.glob('*.npy'))
if not files:
	raise ValueError('No .npy files found in the current working directory.')
data = np.asarray([np.load(f) for f in files])
# Each subject contributes 3 pulses x 4 acquisitions = 12 files.
# NOTE(review): this assumes the sorted file names enumerate subject, then
# pulse, then acquisition; confirm against the naming scheme.
files_per_subject = 3 * 4
if len(data) % files_per_subject:
	raise ValueError('Expected a multiple of %d .npy files, found %d.'
	                 % (files_per_subject, len(data)))
data = np.reshape(data, [len(data) // files_per_subject, 3, 4, np.shape(data)[1]])

fig, axs = plt.subplots(3, 3, sharex='all', sharey='all')

pulses = ['HS', 'GOIA', 'WURST']                      # subplot column j
colors = ['blue', 'orange', 'green']                   # marker face colour per pulse
colori = ['deepskyblue', 'orangered', 'limegreen']     # marker edge colour per pulse
# Bland-Altman differences per pulse and comparison, one entry per subject.
diffs = {pulse: {'same': [], 'repro': [], '2sess': []} for pulse in pulses}

for subject in data:
	for j, pulse in enumerate(pulses):
		acq = subject[j]  # the 4 acquisitions of this pulse
		diff_same = bland_altman_plot(acq[2], acq[3], 0, j, colors[j], colori[j])    # R_0
		diff_2sess3 = bland_altman_plot(acq[1], acq[3], 2, j, colors[j], colori[j])  # R_1,Mc
		diff_repro = bland_altman_plot(acq[0], acq[1], 1, j, colors[j], colori[j])   # R_1,Wc
		diffs[pulse]['same'].append(diff_same)
		diffs[pulse]['repro'].append(diff_repro)
		diffs[pulse]['2sess'].append(diff_2sess3)

# Mean and +/- 1.96*SD lines: row 0 = 'same', row 1 = 'repro', row 2 = '2sess'.
for k, pulse in enumerate(pulses):
	alterman(diffs[pulse]['same'], 0, k)
	alterman(diffs[pulse]['repro'], 1, k)
	alterman(diffs[pulse]['2sess'], 2, k)

fig.text(0.03, 0.55, '$BA_{i,y}$ / a.u.', ha='center', va='center', rotation='vertical')
fig.text(0.55, 0.019, '$BA_{i,x}$ / a.u.', ha='center', va='center')
plt.subplots_adjust(top=0.987, bottom=0.14, left=0.14, right=0.987, wspace=0.05, hspace=0.05)
plt.show()

# Export the total absolute signal of every acquisition (summed over the last
# axis of `data`) to an .xls sheet for the REML analysis.
data_neu = np.sum(np.absolute(data), axis=3)  # -> (subject, pulse, acquisition)
dat = ['HS', 'GOIA', 'WURST']
sess = ['1_1', '1_2', '2_1', '2_2']
wb = Workbook()
sheet1 = wb.add_sheet('Sheet 1')
for col, header in enumerate(['Pulse', 'Subject', 'Session', 'Data']):
	sheet1.write(0, col, header)
# Take the subject count from the data instead of hard-coding 9. Otherwise
# the export crashes with fewer subjects and silently drops any extra ones.
nvol = data_neu.shape[0]
row = 1
for i, pulse in enumerate(dat):
	# NOTE(review): the pulse name and subject number are written only on the
	# first row of their group, leaving the other cells blank. If the REML
	# tool reads this sheet directly, confirm that it does not treat those
	# blanks as missing values.
	sheet1.write(row, 0, pulse)
	for j in range(nvol):
		sheet1.write(row, 1, j + 1)
		for k, session in enumerate(sess):
			sheet1.write(row, 2, session)
			# float(): xlwt rejects numpy scalars that are not float subclasses
			# (e.g. float32, which np.absolute returns for complex64 input).
			sheet1.write(row, 3, float(data_neu[j, i, k]))
			row += 1
wb.save('pulse_shape.xls')