import base64
import io
import json
import tarfile
import warnings
import itertools
import math
import bokehCssPTB
from dsiParser import dsiParser
from bokeh.plotting import curdoc,figure
from bokeh.layouts import column, row
from bokeh.models import FileInput, Div, CustomJS, Button, TabPanel, Tabs, Dropdown, TextInput, Button, MathText, Label, Arrow, NormalHead
from bokeh.palettes import Category10
from bokeh.events import ValueSubmit
import numpy as np
# Ten-entry categorical palette; dsiCompGraphGen indexes into it to colour
# each arrow and its scale-factor label.
colors=Category10[10]
# Application version, exposed to the page template by page.__init__.
VERSION = "0.1.0"
class dsiparserInput():
    """Text field plus button that parses a D-SI unit string and shows the result.

    Attributes:
        dsiInput: TextInput holding the D-SI string typed by the user.
        dsiSubmitButton: "Convert" button that triggers parsing.
        inputRow: row containing the text field and the button.
        results: column that is refilled on every parse.
        widget: top-level layout (input row above the results column).
        resultErrors: column with the parser messages of the last parse
            (only set once parseInput has run).
        dsiTree: result of the last ``dsiParser().parse`` call, or ``None``
            if nothing has been parsed yet.
    """

    def __init__(self):
        self.dsiInput = TextInput(value="", title="DSI unit string:", width=500)
        # Both the text field's ValueSubmit event and a button click run the
        # same handler.
        self.dsiInput.on_event(ValueSubmit, self.parseInput)
        self.dsiSubmitButton = Button(label="Convert", button_type="primary")
        self.dsiSubmitButton.on_click(self.parseInput)
        self.inputRow = row(children = [self.dsiInput, self.dsiSubmitButton], css_classes = ["textInputRow"])
        self.results = column(children = [])
        self.widget = column(children=[self.inputRow, self.results], css_classes=["textlikeColumn"])
        # Defined up front so readers of this attribute get None instead of an
        # AttributeError before the first parse.
        self.dsiTree = None

    @staticmethod
    def _buildMessages(resultTree):
        """Return the Div list describing the parse outcome of *resultTree*."""
        if resultTree.valid:
            return [
                Div(
                    text = "DSI string parsed without warnings",
                    css_classes = ["msg-positive"],
                )
            ]
        # NOTE(review): the warning texts are inserted as HTML unchanged;
        # confirm that dsiParser never echoes raw user input in them.
        return [
            Div(text = message, css_classes = ["msg-negative"])
            for message in resultTree.warnings
        ]

    @staticmethod
    def _latexCodeHtml(latex):
        """Wrap *latex* in ``<pre><code>`` with HTML special characters escaped."""
        import html

        # The LaTeX is derived from user input and is meant to be shown
        # verbatim, so &, < and > must not be interpreted as markup.
        return "<pre><code>" + html.escape(latex, quote=False) + "</code></pre>"

    def parseInput(self):
        """Parse the current text-field value and rebuild the result widgets.

        Replaces the children of ``self.results`` with the parser messages,
        the LaTeX source and a small figure carrying the LaTeX as a label,
        and stores the parse result in ``self.dsiTree``.
        """
        self.results.children = []
        resultTree = dsiParser().parse(self.dsiInput.value)

        self.resultErrors = column(children = self._buildMessages(resultTree))

        latex = resultTree.toLatex()
        # Raw strings: "\m" and "\L" are not valid Python escape sequences.
        latexOutput = row(children=[
            Div(text = r"$$\mathrm{\LaTeX{}}$$ code:"),
            Div(text = self._latexCodeHtml(latex), disable_math=True)
        ])

        # Axis-less, grid-less figure used only as a canvas for the label; the
        # "save" tool is the only tool enabled.
        preview = figure(width=500, height=100, x_range=(0, 1), y_range=(0, 1), tools="save")
        preview.xaxis.visible = False
        preview.yaxis.visible = False
        preview.grid.visible = False
        preview.add_layout(Label(
            x=0.5, y=0.5, text=latex, text_font_size="12px", text_baseline="middle", text_align="center",
        ))

        imageOutput = row(children = [
            Div(text = r"$$\mathrm{\LaTeX{}}$$ output:"),
            preview
        ], css_classes = ["latexImageRow"])
        self.results.children = [self.resultErrors, latexOutput, imageOutput]
        self.dsiTree = resultTree


class dsiCompGraphGen:
    """Triangle graph of two units and their common base unit.

    The nodes "A", "B" and "Base" are drawn as labels; every ordered pair of
    nodes gets a dashed arrow and a label with the scale factor returned by
    ``isScalablyEqualTo``.

    Attributes:
        treeA, treeB: the two unit trees being compared.
        plot: the figure everything is drawn on.
        widget: row wrapping ``plot``.
        unitLabels: node name -> Label.
        arrows: "<from>_<to>" -> Arrow.
        scalFactorLables: "<from>_<to>" -> Label with the scale factor.
    """

    def __init__(self,treeA=None,treeB=None):
        self.treeA=treeA
        self.treeB=treeB
        self.plot= figure(width=500, height=500, x_range=(-1, 1), y_range=(0, 1), tools="save")
        self.plot.xaxis.visible = False
        self.plot.yaxis.visible = False
        self.plot.grid.visible = False
        # Annotations are created once and cached here, so later redraws only
        # update their text instead of stacking new layouts on the plot.
        self.unitLabels={}
        self.arrows={}
        self.scalFactorLables={}
        self.widget=row(self.plot)

    def reDraw(self,treeA=None,treeB=None):
        """Recompute all scale factors and update the graph.

        Args:
            treeA: first unit tree; if ``None`` the stored ``self.treeA`` is used.
            treeB: second unit tree; if ``None`` the stored ``self.treeB`` is used.

        Raises:
            ValueError: if no tree is available for either side.
        """
        if treeA is not None:
            self.treeA=treeA
        if treeB is not None:
            self.treeB=treeB
        if self.treeA is None or self.treeB is None:
            raise ValueError("reDraw needs two unit trees to compare")

        self.scalfactorAB, self.baseUnit = self.treeA.isScalablyEqualTo(self.treeB)
        if self.baseUnit is None:
            # Without a base unit the third node cannot be drawn and the calls
            # below would fail on None, so hide the (possibly stale) graph.
            # NOTE(review): assumes isScalablyEqualTo reports "no common base
            # unit" as None - confirm against dsiParser.
            self.plot.visible = False
            return
        self.plot.visible = True

        self.scalfactorABase, _ = self.treeA.isScalablyEqualTo(self.baseUnit)
        self.scalfactorBaseA, _ = self.baseUnit.isScalablyEqualTo(self.treeA)
        self.scalfactorBaseB, _ = self.baseUnit.isScalablyEqualTo(self.treeB)
        self.scalfactorBA, _ = self.treeB.isScalablyEqualTo(self.treeA)
        self.scalfactorBBase, _ = self.treeB.isScalablyEqualTo(self.baseUnit)

        self.coordinatList = [
            {'coords': (-0.8, 0.8), 'baseUnit': self.treeA, 'name': 'A', 'text_baseLine': 'middle', 'text_align': "right"},
            {'coords': (0.8, 0.8), 'baseUnit': self.treeB, 'name': 'B', 'text_baseLine': 'middle', 'text_align': "left"},
            {'coords': (0.0, 0.2), 'baseUnit': self.baseUnit, 'name': 'Base', 'text_baseLine': 'top', 'text_align': "center"},
        ]

        for node in self.coordinatList:
            self._drawUnitLabel(node)

        # Each unordered pair is drawn in both directions; the return
        # direction is shifted by 0.05 so the two arrows do not coincide, and
        # every direction gets its own palette entry.
        for pairIdx, (node1, node2) in enumerate(itertools.combinations(self.coordinatList, 2)):
            self._drawDirectedEdge(node1, node2, 2 * pairIdx, 0.0)
            self._drawDirectedEdge(node2, node1, 2 * pairIdx + 1, 0.05)

    def _drawUnitLabel(self, node):
        """Create the LaTeX label of *node*, or update its text if it exists."""
        name = node['name']
        latex = node['baseUnit'].toLatex()
        if name not in self.unitLabels:
            x, y = node['coords']
            self.unitLabels[name] = Label(x=x, y=y, text=latex, text_font_size="12px", text_baseline=node['text_baseLine'], text_align=node['text_align'])
            self.plot.add_layout(self.unitLabels[name])
        else:
            self.unitLabels[name].text = latex

    def _drawDirectedEdge(self, source, target, colorIdx, offset):
        """Create or update the arrow and scale-factor label from *source* to *target*."""
        key = source['name'] + '_' + target['name']
        xs, ys = source['coords']
        xt, yt = target['coords']

        if key not in self.arrows:
            color = colors[colorIdx]
            head = NormalHead(fill_color=color, fill_alpha=0.5, line_color=color)
            self.arrows[key] = Arrow(end=head, line_color=color, line_dash=[15, 5], x_start=xs, y_start=ys + offset, x_end=xt, y_end=yt + offset)
            self.plot.add_layout(self.arrows[key])

        scale, _ = source['baseUnit'].isScalablyEqualTo(target['baseUnit'])
        scaleText = "{:.4g}".format(scale)
        if key not in self.scalFactorLables:
            # Placed at the midpoint of the edge. Baseline/alignment are fixed
            # to "top"/"center" (previously taken from a leaked loop variable
            # that always held the "Base" node settings).
            self.scalFactorLables[key] = Label(x=(xs + xt) / 2 + offset, y=(ys + yt) / 2 + offset, text=scaleText, text_font_size="12px", text_baseline="top", text_align="center", text_color=colors[colorIdx])
            self.plot.add_layout(self.scalFactorLables[key])
        else:
            self.scalFactorLables[key].text = scaleText

class page():
    """Top-level Bokeh page: two unit inputs, a compare row and the comparison graph.

    Constructing an instance registers all widgets as roots of the current
    document (``curdoc()``) and sets the document title, theme and the
    ``VERSION`` template variable.
    """

    def __init__(self):
        doc = curdoc()
        doc.template_variables["VERSION"] = VERSION
        doc.title = "DSI to Latex"
        doc.add_root(bokehCssPTB.getStyleDiv())
        doc.theme = bokehCssPTB.getTheme()

        self.dsiInput1 = dsiparserInput()
        self.dsiInput2 = dsiparserInput()
        self.inputs=row([self.dsiInput1.widget, self.dsiInput2.widget])
        doc.add_root(self.inputs)

        self.comapreButton = Button(label="Compare", button_type="primary")
        self.comapreButton.on_click(self.compare)
        self.compaReresult = Div(text = "", css_classes = ["msg-positive"])
        self.compareRow = row(children = [self.comapreButton,self.compaReresult], css_classes = ["textInputRow"])
        doc.add_root(self.compareRow)

        # The graph takes parsed unit trees, not the input widgets; none exist
        # yet, so it starts empty and receives the trees in compare().
        self.dsiCompGraphGen=dsiCompGraphGen()
        doc.add_root(self.dsiCompGraphGen.widget)

    def compare(self):
        """Re-parse both inputs, report whether they are scalably equal and redraw the graph.

        A NaN scale factor from ``isScalablyEqualTo`` is treated as "not equal".
        """
        self.dsiInput1.parseInput()
        self.dsiInput2.parseInput()
        treeA = self.dsiInput1.dsiTree
        treeB = self.dsiInput2.dsiTree

        scalfactor,baseUnit=treeA.isScalablyEqualTo(treeB)
        if math.isnan(scalfactor):
            self.compaReresult.text = "The two units are not equal"
            self.compaReresult.css_classes = ["msg-negative"]
        else:
            self.compaReresult.text = f"The two units are equal up to a scaling factor of {scalfactor!s} and a base unit of {baseUnit!s}"
            self.compaReresult.css_classes=["msg-positive"]
        # The message is set first; the graph is redrawn in both cases.
        self.dsiCompGraphGen.reDraw(treeA,treeB)
# Build the page when the module is executed; page.__init__ registers all
# widgets on the current Bokeh document.
thisPage = page()