import base64
import io
import json
import tarfile
import warnings
import itertools
import math
import bokehCssPTB
from urllib.parse import quote
from dsiParser import dsiParser
from bokeh.plotting import curdoc,figure
from bokeh.layouts import column, row
from bokeh.models import FileInput, Div, CustomJS, Button, TabPanel, Tabs, Dropdown, TextInput, Button, MathText, Label, Arrow, NormalHead
from bokeh.palettes import Category10
from bokeh.events import ValueSubmit
import numpy as np
# Ten-entry categorical palette; dsiCompGraphGen indexes it with the per-arrow colour index.
colors=Category10[10]
# Application version; page.__init__ exposes it as the "VERSION" template variable.
VERSION = "0.1.0"

# Markdown body of the pre-filled issue report. The eleven positional `{}` placeholders are
# filled by page.createIssueUrl in this order: left input, left parsed tree, right input,
# right parsed tree, base unit, then the six scale factors
# (A->B, B->A, A->base, base->A, B->base, base->B).
issueTemplate="""There was an issue with the following unit conversion:

| Which unit | User entered input | Parsed input |
| ---------- | ------------------ | ------------ |
| Left unit  | {} | {} |
| Right unit | {} | {} |

### Result/Expectation: 

| Calculation | Result | Expectation | Did match |
|-------------|--------|-------------|-----------|
| Base unit   | {} |     | [ ]       |
| factor * left = right | {} |     | [ ] |
| factor * right = left | {} |     | [ ] |
| factor * left = base | {} |     | [ ] |
| factor * base = left | {} |     | [ ] |
| factor * right = base | {} |     | [ ] |
| factor * base = right | {} |     | [ ] |

If there was an error with the calculation, please fill out the table above. Feel free to also add additional info here: 

*Free text comment*"""

# Offsets added to the midpoint between two units when placing a scale-factor label in
# dsiCompGraphGen.reDraw; indexed by the same colour index (0-5) as the arrows.
labelOffsetX=[0.0,0.0,-0.02,0.07,0.02,-0.07]
labelOffsetY=[-0.02,0.09,0,0.07,0,0.07]
class dsiparserInput():
    """Text input for one D-SI unit string plus the widgets that show its parse result.

    Attributes:
        additionalComparisonCallbacks: zero-argument callables that
            ``parseInpuitWithCallbacls`` runs after parsing.
        dsiInput: ``TextInput`` holding the unit string.
        dsiSubmitButton: "Convert" button that triggers ``parseInput``.
        inputRow: row containing the text input and the button.
        results: column that receives the parser messages and the LaTeX output.
        widget: top-level layout (input row above the results).
        valideUnit: whether the most recent parse yielded a valid tree.
        dsiTree: tree returned by the most recent parse; ``None`` before the first parse.
    """

    def __init__(self,defaultInput="",additionalComparisonCallbacks=None):
        """Build the widgets and register ``parseInput`` for submit and button click.

        Args:
            defaultInput: initial content of the text input.
            additionalComparisonCallbacks: optional list of zero-argument callables;
                ``None`` (the default) means no callbacks.
        """
        # A literal [] default would be one list shared by every instance.
        if additionalComparisonCallbacks is None:
            additionalComparisonCallbacks = []
        self.additionalComparisonCallbacks=additionalComparisonCallbacks
        self.dsiInput = TextInput(value=defaultInput, title="DSI unit string:", width=500)
        self.dsiInput.on_event(ValueSubmit, self.parseInput)
        self.dsiSubmitButton = Button(label="Convert", button_type="primary")
        self.dsiSubmitButton.on_click(self.parseInput)
        self.inputRow = row(children = [self.dsiInput, self.dsiSubmitButton], css_classes = ["textInputRow"])
        self.results = column(children = [])
        self.widget = column(children=[self.inputRow, self.results], css_classes=["doubleColumn"])
        self.valideUnit = False
        self.dsiTree = None

    def parseInpuitWithCallbacls(self):
        """Parse the input, then run every callback in ``additionalComparisonCallbacks``.

        NOTE(review): nothing in this class registers this method as a handler; the
        widgets call ``parseInput`` directly, so the callbacks only run when a caller
        invokes this method explicitly.
        """
        self.parseInput()
        for callback in self.additionalComparisonCallbacks:
            callback()

    def parseInput(self):
        """Parse the current input string and rebuild the result widgets.

        Updates ``results``, ``resultErrors``, ``dsiTree`` and ``valideUnit``.
        ``valideUnit`` always reflects the latest parse, so an invalid input that
        follows a valid one resets it to ``False``.
        """
        self.results.children = []
        resultTree = dsiParser().parse(self.dsiInput.value)

        if resultTree.valid:
            parsingMessages = [
                Div(
                    text = "DSI string parsed without warnings",
                    css_classes = ["msg-positive"],
                )
            ]
        else:
            parsingMessages = [
                Div(text = message, css_classes = ["msg-negative"])
                for message in resultTree.warnings
            ]
        self.resultErrors = column(children = parsingMessages)

        latexCode = resultTree.toLatex()

        # Raw strings: "\m" and "\L" are not valid escape sequences in normal literals.
        latexOutput = row(children=[
            Div(text = r"$$\mathrm{\LaTeX{}}$$ code:"),
            Div(text = "<pre><code>"+latexCode+"</code></pre>", disable_math=True)
        ])

        # Axis-less, grid-less figure used only as a canvas for the rendered LaTeX label.
        latexFigure = figure(width=500, height=100, x_range=(0, 1), y_range=(0, 1), tools="save")
        latexFigure.xaxis.visible = False
        latexFigure.yaxis.visible = False
        latexFigure.grid.visible = False

        latexFigure.add_layout(Label(
            x=0.5, y=0.5, text=latexCode, text_font_size="30px", text_baseline="middle", text_align="center",
        ))

        imageOutput = row(children = [
            Div(text = r"$$\mathrm{\LaTeX{}}$$ output:"),
            latexFigure
        ], css_classes = ["latexImageRow"])
        self.results.children = [self.resultErrors, latexOutput, imageOutput]
        self.dsiTree = resultTree
        # Track the latest parse in both directions; previously a True was never cleared.
        self.valideUnit = bool(resultTree.valid)

def removeArrowsAndLabelsFromPlot(plot):
    """Remove every ``Arrow`` and ``Label`` from the layout panels of *plot*.

    The ``left``, ``right``, ``above``, ``below`` and ``center`` lists of both
    ``plot`` and ``plot.plot`` are scanned.

    Args:
        plot: object exposing the five panel lists and a ``plot`` attribute
            (presumably a Bokeh figure, whose ``plot`` is the figure itself - TODO confirm).
    """
    for location in ('left', 'right', 'above', 'below', 'center'):
        for plotObj in (plot, plot.plot):
            panel = getattr(plotObj, location)
            # Iterate over a snapshot: removing from the list being iterated would skip
            # the element that follows each removed one.
            for layout in list(panel):
                if isinstance(layout, (Arrow, Label)):
                    panel.remove(layout)

class dsiCompGraphGen:
    """Plot of unit A, unit B and their base unit with the scale factors between them.

    The three units are drawn as labels at fixed positions. Each pair of units is
    joined by two dashed arrows (one per direction), each with a label showing the
    scale factor returned by ``isScalablyEqualTo``.

    Attributes:
        plot: the figure everything is drawn on.
        widget: row wrapping ``plot``.
        unitLabels: unit name ('A', 'B', 'Base') -> ``Label``.
        arrows: '<from>_<to>' -> ``Arrow``.
        scalFactorLables: '<from>_<to>' -> ``Label`` with the scale factor.
    """

    def flush(self):
        """Remove all renderers, arrows and labels from the plot and reset the bookkeeping dicts."""
        # Iterate over a snapshot: removing from the live list while iterating it skips entries.
        for renderer in list(self.plot.renderers):
            self.plot.renderers.remove(renderer)
        removeArrowsAndLabelsFromPlot(self.plot)
        # Deliberate second pass: a single call was observed to leave the three arrows
        # generated last on the plot.
        removeArrowsAndLabelsFromPlot(self.plot)
        self.widget = row(self.plot)
        self.unitLabels = {}
        self.arrows = {}
        self.scalFactorLables = {}

    def __init__(self,treeA=None,treeB=None):
        """Create the empty plot.

        Args:
            treeA: stored as ``self.treeA``; replaced on every ``reDraw``.
            treeB: stored as ``self.treeB``; replaced on every ``reDraw``.
        """
        self.treeA=treeA
        self.treeB=treeB
        self.plot = figure(width=550, height=500, x_range=(-1, 1), y_range=(0, 1), tools="save")
        self.plot.xaxis.visible = False
        self.plot.yaxis.visible = False
        self.plot.grid.visible = False
        # flush() also creates self.widget and the empty bookkeeping dicts.
        self.flush()

    def _drawLink(self, key, colorIdx, fromTree, toTree, start, end, labelPos, labelAngle):
        """Ensure arrow and scale-factor label for one direction exist and show the current factor.

        Args:
            key: '<from>_<to>' key into ``arrows`` and ``scalFactorLables``.
            colorIdx: index into ``colors``, ``labelOffsetX`` and ``labelOffsetY``.
            fromTree: unit the factor converts from.
            toTree: unit the factor converts to.
            start: (x, y) of the arrow tail.
            end: (x, y) of the arrow head.
            labelPos: (x, y) of the label before the per-colour offset is added.
            labelAngle: rotation passed to the label as ``angle``.
        """
        color = colors[colorIdx]
        if key not in self.arrows:
            head = NormalHead(fill_color=color, fill_alpha=0.5, line_color=color)
            self.arrows[key]=Arrow(end=head, line_color=color, line_dash=[15, 5], x_start=start[0], y_start=start[1], x_end=end[0], y_end=end[1])
            self.plot.add_layout(self.arrows[key])

        scale, _ = fromTree.isScalablyEqualTo(toTree)
        scaleText = "{:.4g}".format(scale)
        if key not in self.scalFactorLables:
            # Baseline/alignment are the fixed values the labels always received
            # (previously read from the leaked loop variable of the unit-label loop).
            self.scalFactorLables[key]=Label(x=labelPos[0]+labelOffsetX[colorIdx], y=labelPos[1]+labelOffsetY[colorIdx], text=scaleText, text_font_size="24px", text_baseline='top', text_align="center",text_color=color,angle=labelAngle)
            self.plot.add_layout(self.scalFactorLables[key])
        else:
            self.scalFactorLables[key].text=scaleText

    def reDraw(self,treeA=None,treeB=None):
        """Compute all scale factors for the two trees and draw or update the graph.

        Sets ``baseUnit`` and the six ``scalfactor*`` attributes, then creates the
        unit labels, arrows and scale-factor labels on first use and only updates
        the label texts afterwards.

        Args:
            treeA: parsed tree of the left unit.
            treeB: parsed tree of the right unit.
        """
        self.treeA=treeA
        self.treeB=treeB
        self.scalfactorAB, self.baseUnit = self.treeA.isScalablyEqualTo(self.treeB)
        self.scalfactorABase, _ = self.treeA.isScalablyEqualTo(self.baseUnit)

        self.scalfactorBaseA, _ = self.baseUnit.isScalablyEqualTo(self.treeA)
        self.scalfactorBaseB, _ = self.baseUnit.isScalablyEqualTo(self.treeB)

        self.scalfactorBA, _ = self.treeB.isScalablyEqualTo(self.treeA)
        self.scalfactorBBase, _ = self.treeB.isScalablyEqualTo(self.baseUnit)
        self.coordinatList = [{'coords':(-0.8, 0.8),'baseUnit':self.treeA,'name':'A','text_baseLine':'middle','text_align':"right"},
                              {'coords': (0.8, 0.8), 'baseUnit': self.treeB,'name':'B','text_baseLine':'middle','text_align':"left"},
                              {'coords': (0.0, 0.2), 'baseUnit': self.baseUnit,'name':'Base','text_baseLine':'top','text_align':"center"},]

        for unit in self.coordinatList:
            x, y = unit['coords']
            name = unit['name']
            latex = unit['baseUnit'].toLatex()
            if name not in self.unitLabels:
                self.unitLabels[name]=Label(x=x, y=y, text=latex, text_font_size="24px", text_baseline=unit['text_baseLine'], text_align=unit['text_align'],)
                self.plot.add_layout(self.unitLabels[name])
            else:
                self.unitLabels[name].text=latex

        # Each unordered pair uses two consecutive colour indices: even for the
        # first->second direction, odd for second->first.
        for pairIdx, (quant1, quant2) in enumerate(itertools.combinations(self.coordinatList, 2)):
            colIDX = 2*pairIdx
            x1, y1 = quant1['coords']
            x2, y2 = quant2['coords']
            name1 = quant1['name']
            name2 = quant2['name']

            # Both directions share the same label rotation and the same base position.
            labelAngle = -1*np.arctan2(y2 - y1, x2 - x1)
            if abs(labelAngle)>np.pi/8:
                labelAngle+=-np.pi/2
            labelPos = (np.abs(x1-x2)/2+np.min([x1,x2]), np.abs(y1-y2)/2+np.min([y1,y2]))

            self._drawLink(name1+'_'+name2, colIDX, quant1['baseUnit'], quant2['baseUnit'],
                           (x1, y1), (x2, y2), labelPos, labelAngle)
            # The return arrow is shifted up by 0.05 so it does not cover the forward one.
            self._drawLink(name2+'_'+name1, colIDX+1, quant2['baseUnit'], quant1['baseUnit'],
                           (x2, y2+0.05), (x1, y1+0.05), labelPos, labelAngle)

class page():
    """Bokeh document with two unit inputs, a compare button, the comparison graph and an issue-report button."""

    def __init__(self):
        """Configure the current document and add all widgets as roots."""
        doc = curdoc()
        doc.template_variables["VERSION"] = VERSION
        doc.title = "DSI to Latex"
        doc.add_root(bokehCssPTB.getStyleDiv())
        doc.theme = bokehCssPTB.getTheme()
        self.dsiInput1 = dsiparserInput(defaultInput="\\milli\\newton\\metre",additionalComparisonCallbacks=[self.clearComparison,self.tryComparison])
        self.dsiInput2 = dsiparserInput(defaultInput="\\kilo\\joule",additionalComparisonCallbacks=[self.clearComparison,self.tryComparison])
        self.inputs=row([self.dsiInput1.widget, self.dsiInput2.widget])
        doc.add_root(self.inputs)

        self.comapreButton = Button(label="Compare", button_type="primary")
        self.comapreButton.on_click(self.compare)
        self.compaReresult = Div(text = "", css_classes = ["msg-positive"])
        self.compareRow = row(children = [self.comapreButton,self.compaReresult], css_classes = ["textInputRow"])
        doc.add_root(self.compareRow)
        self.dsiCompGraphGen=dsiCompGraphGen(self.dsiInput1,self.dsiInput2)
        # Added once; the original second add_root of the same widget was redundant.
        doc.add_root(self.dsiCompGraphGen.widget)
        # Stays disabled until the first comparison has produced something to report.
        self.createIssueButton = Button(label="Report conversion error", disabled=True)
        doc.add_root(self.createIssueButton)

    def compare(self):
        """Parse both inputs, show the comparison result and refresh graph and issue button.

        An ``AttributeError`` raised while comparing or drawing is reported as a
        warning and treated as "not equal".
        """
        self.dsiInput1.parseInput()
        self.dsiInput2.parseInput()
        try:
            scalfactor,baseUnit=self.dsiInput1.dsiTree.isScalablyEqualTo(self.dsiInput2.dsiTree)
            if not math.isnan(scalfactor):
                self.compaReresult.text = "The two units are equal up to a scaling factor of "+str(scalfactor)+" and a base unit of "+str(baseUnit)
                self.compaReresult.css_classes=["msg-positive"]
            else:
                self.compaReresult.text = "The two units are not equal"
                self.compaReresult.css_classes = ["msg-negative"]
            if self.dsiInput1.valideUnit and self.dsiInput2.valideUnit:
                self.dsiCompGraphGen.reDraw(self.dsiInput1.dsiTree,self.dsiInput2.dsiTree)
            else:
                self.dsiCompGraphGen.flush()
        except AttributeError as Ae:
            warnings.warn("AttributeError: "+str(Ae))
            self.compaReresult.text = "The two units are not equal"
            self.compaReresult.css_classes = ["msg-negative"]
            self.dsiCompGraphGen.flush()
        self.createIssueButton.disabled=False
        self.createIssueButton.button_type="danger"
        # Replace the click handler instead of appending another one: js_on_event
        # accumulates callbacks, so every earlier comparison would open its own
        # (stale) issue window on a single click.
        issueCallback = CustomJS(code=f"window.open('{self.createIssueUrl()}', '_blank');")
        jsCallbacks = dict(self.createIssueButton.js_event_callbacks)
        jsCallbacks["button_click"] = [issueCallback]
        self.createIssueButton.js_event_callbacks = jsCallbacks

    def createIssueUrl(self):
        """Return the URL of a new-issue form pre-filled with the current comparison.

        Attributes that ``dsiCompGraphGen`` does not have yet are reported as an
        "AttributeError: ..." text in place of the value.
        """
        issueArgs=[self.dsiInput1.dsiInput.value,str(self.dsiInput1.dsiTree),self.dsiInput2.dsiInput.value,str(self.dsiInput2.dsiTree)]
        # Order must match the placeholders of issueTemplate.
        graphAttrsForIssue=('baseUnit', 'scalfactorAB', 'scalfactorBA', 'scalfactorABase', 'scalfactorBaseA', 'scalfactorBBase', 'scalfactorBaseB')
        for attrName in graphAttrsForIssue:
            try:
                issueArgs.append(str(getattr(self.dsiCompGraphGen,attrName)))
            except AttributeError as Ae:
                issueArgs.append("AttributeError: "+str(Ae))
        filledResult=issueTemplate.format(*issueArgs)
        filledTitle = f'Unexpected comparison result: {self.dsiInput1.dsiInput.value} to {self.dsiInput2.dsiInput.value}'
        issueUrl = r'https://gitlab1.ptb.de/digitaldynamicmeasurement/dsi-parser-frontend/-/issues/new?'
        # quote() percent-encodes quotes and backslashes, so the result can be embedded
        # in the single-quoted JavaScript string built in compare().
        return issueUrl + 'issue[title]=' + quote(filledTitle) + '&issue[description]=' + quote(filledResult)

    def clearComparison(self):
        """Clear the graph and result text and disable the issue button."""
        self.dsiCompGraphGen.flush()
        self.compaReresult.text = ""
        self.createIssueButton.disabled=True
        self.createIssueButton.button_type="primary"

    def tryComparison(self):
        """Run ``compare`` only when both inputs currently hold a valid unit."""
        if self.dsiInput1.valideUnit and self.dsiInput2.valideUnit:
            self.compare()

# Instantiate the page at import time: page.__init__ adds all widgets to curdoc().
thisPage = page()