import base64
import html
import io
import itertools
import json
import math
import tarfile
import warnings

import numpy as np
from bokeh.events import ValueSubmit
from bokeh.layouts import column, row
from bokeh.models import (
    Arrow,
    Button,
    CustomJS,
    Div,
    Dropdown,
    FileInput,
    Label,
    MathText,
    NormalHead,
    TabPanel,
    Tabs,
    TextInput,
)
from bokeh.palettes import Category10
from bokeh.plotting import curdoc, figure

import bokehCssPTB
from dsiParser import dsiParser
# Ten-colour palette; dsiCompGraphGen picks one colour per arrow/label via colIDX and colIDX+1.
colors=Category10[10]
# Application version, exposed to the page template as the "VERSION" template variable.
VERSION = "0.1.0"
class dsiparserInput():
    """Text input that parses a D-SI unit string and displays the result.

    ``self.widget`` is a column holding the input row (text field plus a
    "Convert" button) and a results column. ``parseInput`` fills the results
    column with parser messages, the generated LaTeX source and a rendered
    preview.
    """

    def __init__(self,defaultInput=""):
        """Build the input row and the initially empty results column.

        Args:
            defaultInput: D-SI string initially shown in the text field.
        """
        self.dsiInput = TextInput(value=defaultInput, title="DSI unit string:", width=500)
        # Pressing Enter in the text field and clicking the button both re-parse.
        self.dsiInput.on_event(ValueSubmit, self.parseInput)
        self.dsiSubmitButton = Button(label="Convert", button_type="primary")
        self.dsiSubmitButton.on_click(self.parseInput)
        self.inputRow = row(children = [self.dsiInput, self.dsiSubmitButton], css_classes = ["textInputRow"])
        self.results = column(children = [])
        self.widget = column(children=[self.inputRow, self.results], css_classes=["textlikeColumn"])
        # True only while the most recently parsed string was valid.
        self.valideUnit = False
        # Parse tree of the most recent parse; None until parseInput has run.
        self.dsiTree = None

    def parseInput(self):
        """Parse the current text field value and rebuild the results column.

        Side effects: sets ``self.dsiTree`` to the parser result,
        ``self.valideUnit`` to whether that result is valid, and
        ``self.resultErrors`` to the column of parser messages.
        """
        self.results.children = []
        dsiString = self.dsiInput.value
        resultTree = dsiParser().parse(dsiString)
        latexCode = resultTree.toLatex()

        if resultTree.valid:
            parsingMessages = [
                Div(text="DSI string parsed without warnings", css_classes=["msg-positive"]),
            ]
        else:
            parsingMessages = [
                Div(text=message, css_classes=["msg-negative"])
                for message in resultTree.warnings
            ]
        self.resultErrors = column(children=parsingMessages)

        latexOutput = row(children=[
            # Raw string: "\m" and "\L" are not valid Python escape sequences.
            Div(text=r"$$\mathrm{\LaTeX{}}$$ code:"),
            # Div text is HTML; escape the code so characters such as "<" or "&"
            # are shown literally instead of being interpreted as markup.
            Div(text="<pre><code>" + html.escape(latexCode, quote=False) + "</code></pre>", disable_math=True),
        ])

        preview = figure(width=500, height=100, x_range=(0, 1), y_range=(0, 1), tools="save")
        preview.xaxis.visible = False
        preview.yaxis.visible = False
        preview.grid.visible = False
        preview.add_layout(Label(
            x=0.5, y=0.5, text=latexCode, text_font_size="12px", text_baseline="middle", text_align="center",
        ))

        imageOutput = row(children = [
            Div(text = r"$$\mathrm{\LaTeX{}}$$ output:"),
            preview,
        ], css_classes = ["latexImageRow"])
        self.results.children = [self.resultErrors, latexOutput, imageOutput]
        self.dsiTree = resultTree
        # Reset on every parse so an earlier valid string does not mask a later invalid one.
        self.valideUnit = bool(resultTree.valid)

def removeArrowsAndLabelsFromPlot(plot):
    """Remove every ``Arrow`` and ``Label`` annotation from a Bokeh figure.

    The 'left', 'right', 'above', 'below' and 'center' layout lists of both
    ``plot`` and ``plot.plot`` are scanned. All other layout objects (axes,
    grids, ...) are left in place.

    Args:
        plot: Bokeh figure to clean up; it is modified in place.
    """
    for location in ['left', 'right', 'above', 'below', 'center']:
        for plotObj in [plot, plot.plot]:
            annotations = getattr(plotObj, location)
            # Iterate over a snapshot: removing items from the list being
            # iterated would skip the element that follows each removed one.
            for layout in list(annotations):
                if isinstance(layout, (Arrow, Label)):
                    annotations.remove(layout)

class dsiCompGraphGen:
    """Diagram comparing two D-SI units with their common base unit.

    Unit A, unit B and their base unit are drawn as labels at fixed positions
    in a 500x500 figure. Each pair of units is joined by two dashed arrows,
    one per direction. Each arrow is annotated with the scale factor that
    ``isScalablyEqualTo`` returns for that direction.
    """

    def __init__(self,treeA=None,treeB=None):
        """Create the empty figure and the row layout that wraps it.

        Args:
            treeA: Initial value of ``self.treeA``; replaced by ``reDraw``.
            treeB: Initial value of ``self.treeB``; replaced by ``reDraw``.
        """
        self.treeA = treeA
        self.treeB = treeB
        self.plot = figure(width=500, height=500, x_range=(-1, 1), y_range=(0, 1), tools="save")
        # Pure annotation canvas: hide axes and grid.
        self.plot.xaxis.visible = False
        self.plot.yaxis.visible = False
        self.plot.grid.visible = False
        self.flush()
        self.widget = row(self.plot)

    def flush(self):
        """Remove all renderers, arrows and labels and forget the cached annotations.

        The layout in ``self.widget`` is kept, so the (now empty) figure stays
        attached to the document it was added to.
        """
        # Assign a fresh list instead of removing while iterating, which skips elements.
        self.plot.renderers = []
        removeArrowsAndLabelsFromPlot(self.plot)
        # Extra call kept as a safeguard: a single call was observed to leave
        # some arrows behind. It is a no-op once everything has been removed.
        removeArrowsAndLabelsFromPlot(self.plot)
        self.unitLabels = {}
        self.arrows = {}
        self.scalFactorLables = {}

    def reDraw(self,treeA=None,treeB=None):
        """Draw the diagram for two parsed units, reusing existing annotations.

        Annotations created by an earlier call (and not removed by ``flush``)
        only get their text updated; missing ones are created.

        Args:
            treeA: Parsed tree of the first unit, drawn at the left ("A").
            treeB: Parsed tree of the second unit, drawn at the right ("B").
        """
        self.treeA = treeA
        self.treeB = treeB
        self.scalfactorAB, self.baseUnit = self.treeA.isScalablyEqualTo(self.treeB)
        # Scale factors for every ordered pair, stored as attributes; the
        # returned base units are not needed here.
        self.scalfactorABase, _ = self.treeA.isScalablyEqualTo(self.baseUnit)
        self.scalfactorBaseA, _ = self.baseUnit.isScalablyEqualTo(self.treeA)
        self.scalfactorBaseB, _ = self.baseUnit.isScalablyEqualTo(self.treeB)
        self.scalfactorBA, _ = self.treeB.isScalablyEqualTo(self.treeA)
        self.scalfactorBBase, _ = self.treeB.isScalablyEqualTo(self.baseUnit)
        self.coordinatList = [{'coords':(-0.8, 0.8),'baseUnit':self.treeA,'name':'A','text_baseLine':'middle','text_align':"right"},
                              {'coords': (0.8, 0.8), 'baseUnit': self.treeB,'name':'B','text_baseLine':'middle','text_align':"left"},
                              {'coords': (0.0, 0.2), 'baseUnit': self.baseUnit,'name':'Base','text_baseLine':'top','text_align':"center"},]

        for unitToDraw in self.coordinatList:
            x, y = unitToDraw['coords']
            name = unitToDraw['name']
            latex = unitToDraw['baseUnit'].toLatex()
            if name not in self.unitLabels:
                self.unitLabels[name] = Label(x=x, y=y, text=latex, text_font_size="12px",
                                              text_baseline=unitToDraw['text_baseLine'],
                                              text_align=unitToDraw['text_align'])
                self.plot.add_layout(self.unitLabels[name])
            else:
                self.unitLabels[name].text = latex

        for pairIdx, (quant1, quant2) in enumerate(itertools.combinations(self.coordinatList, 2)):
            (x1, y1), (x2, y2) = quant1['coords'], quant2['coords']
            # Both labels of a pair share one rotation derived from the line
            # quant1 -> quant2, in radians (Bokeh's default Label angle unit).
            # Steep lines get an extra -pi/2 rotation.
            angle_rad = -np.arctan2(y2 - y1, x2 - x1)
            if abs(angle_rad) > np.pi / 8:
                angle_rad -= np.pi / 2
            self._drawDirection(quant1, quant2, colors[2 * pairIdx], angle_rad, 0.0)
            # The reverse arrow and its label are shifted by 0.05 so they do not overlap the forward ones.
            self._drawDirection(quant2, quant1, colors[2 * pairIdx + 1], angle_rad, 0.05)

    def _drawDirection(self, src, dst, color, angle_rad, offset):
        """Draw or update the arrow from unit ``src`` to unit ``dst`` and its scale-factor label.

        Args:
            src: Entry of ``self.coordinatList`` the arrow starts at.
            dst: Entry of ``self.coordinatList`` the arrow points to.
            color: Colour of the arrow, its head and its label.
            angle_rad: Rotation of the label in radians.
            offset: Vertical shift of the arrow, and diagonal shift of the label.
        """
        key = src['name'] + '_' + dst['name']
        (xs, ys), (xd, yd) = src['coords'], dst['coords']
        if key not in self.arrows:
            head = NormalHead(fill_color=color, fill_alpha=0.5, line_color=color)
            self.arrows[key] = Arrow(end=head, line_color=color, line_dash=[15, 5],
                                     x_start=xs, y_start=ys + offset, x_end=xd, y_end=yd + offset)
            self.plot.add_layout(self.arrows[key])
        scale, _ = src['baseUnit'].isScalablyEqualTo(dst['baseUnit'])
        text = "{:.4g}".format(scale)
        if key not in self.scalFactorLables:
            # The label sits at the midpoint of the connection, shifted by ``offset`` in x and y.
            self.scalFactorLables[key] = Label(
                x=(xs + xd) / 2 + offset, y=(ys + yd) / 2 + offset, text=text,
                text_font_size="12px", text_baseline='top', text_align='center',
                text_color=color, angle=angle_rad)
            self.plot.add_layout(self.scalFactorLables[key])
        else:
            self.scalFactorLables[key].text = text

class page():
    """Top-level Bokeh document for comparing two D-SI unit strings.

    Adds to ``curdoc()``:
    - a style div;
    - a row with two parser inputs;
    - a compare button with its result text;
    - the comparison diagram.
    """

    def __init__(self):
        """Configure the current document and add all root layouts to it."""
        doc = curdoc()
        doc.template_variables["VERSION"] = VERSION
        doc.title = "DSI to Latex"
        doc.add_root(bokehCssPTB.getStyleDiv())
        doc.theme = bokehCssPTB.getTheme()
        self.dsiInput1 = dsiparserInput(defaultInput="\\milli\\newton\\metre")
        self.dsiInput2 = dsiparserInput(defaultInput="\\kilo\\joule")
        self.inputs = row([self.dsiInput1.widget, self.dsiInput2.widget])
        doc.add_root(self.inputs)

        self.comapreButton = Button(label="Compare", button_type="primary")
        self.comapreButton.on_click(self.compare)
        self.compaReresult = Div(text = "", css_classes = ["msg-positive"])
        self.compareRow = row(children = [self.comapreButton,self.compaReresult], css_classes = ["textInputRow"])
        doc.add_root(self.compareRow)
        # The diagram receives the parsed trees in reDraw(); there is nothing to pass yet
        # (the input widgets are not parse trees).
        self.dsiCompGraphGen = dsiCompGraphGen()
        doc.add_root(self.dsiCompGraphGen.widget)

    def compare(self):
        """Parse both inputs, report whether they are scalably equal, and update the diagram.

        An ``AttributeError`` raised while comparing or drawing is turned into
        a warning, a negative result message and an emptied diagram.
        """
        self.dsiInput1.parseInput()
        self.dsiInput2.parseInput()
        try:
            scalfactor, baseUnit = self.dsiInput1.dsiTree.isScalablyEqualTo(self.dsiInput2.dsiTree)
            if not math.isnan(scalfactor):
                self.compaReresult.text = "The two units are equal up to a scaling factor of "+str(scalfactor)+" and a base unit of "+str(baseUnit)
                self.compaReresult.css_classes = ["msg-positive"]
            else:
                self.compaReresult.text = "The two units are not equal"
                self.compaReresult.css_classes = ["msg-negative"]
            if self.dsiInput1.valideUnit and self.dsiInput2.valideUnit:
                self.dsiCompGraphGen.reDraw(self.dsiInput1.dsiTree, self.dsiInput2.dsiTree)
            else:
                self.dsiCompGraphGen.flush()
        except AttributeError as Ae:
            warnings.warn("AttributeError: "+str(Ae))
            self.compaReresult.text = "The two units are not equal"
            self.compaReresult.css_classes = ["msg-negative"]
            self.dsiCompGraphGen.flush()
# Module-level instantiation: page() builds the UI and registers every root
# layout on curdoc(). Presumably this script is executed by `bokeh serve`,
# once per session (TODO confirm). Do not wrap it in an
# `if __name__ == "__main__"` guard without checking which module name the server uses.
thisPage = page()