From 78f13dd3ca605b136d984c950c1b4a8addde3341 Mon Sep 17 00:00:00 2001
From: Benedikt Seeger <benedikt.seeger@ptb.de>
Date: Thu, 18 Jan 2024 09:54:13 +0100
Subject: [PATCH] improved robustness

---
 main.py | 68 ++++++++++++++++++++++++++++++++++++++++++---------------
 1 file changed, 51 insertions(+), 17 deletions(-)

diff --git a/main.py b/main.py
index 588f807..6b72452 100644
--- a/main.py
+++ b/main.py
@@ -17,14 +17,15 @@ colors=Category10[10]
 VERSION = "0.1.0"
 class dsiparserInput():
 
-    def __init__(self):
-        self.dsiInput = TextInput(value="", title="DSI unit string:", width=500)
+    def __init__(self,defaultInput=""):
+        self.dsiInput = TextInput(value=defaultInput, title="DSI unit string:", width=500)
         self.dsiInput.on_event(ValueSubmit, self.parseInput)
         self.dsiSubmitButton = Button(label="Convert", button_type="primary")
         self.dsiSubmitButton.on_click(self.parseInput)
         self.inputRow = row(children = [self.dsiInput, self.dsiSubmitButton], css_classes = ["textInputRow"])
         self.results = column(children = [])
         self.widget = column(children=[self.inputRow, self.results], css_classes=["textlikeColumn"])
+        self.valideUnit = False
 
     def parseInput(self):
         self.results.children = []
@@ -72,23 +73,40 @@ class dsiparserInput():
         ], css_classes = ["latexImageRow"])
         self.results.children = [self.resultErrors, latexOutput, imageOutput]
         self.dsiTree = resultTree
+        if resultTree.valid:
+            self.valideUnit = True
 
+def removeArrowsAndLabelsFromPlot(plot):
+    for location in ['left', 'right', 'above', 'below', 'center']:
+        for plotObj in [plot,plot.plot]:
+            for layout in plotObj.__getattr__(location):
+                if isinstance(layout, Arrow) or isinstance(layout, Label):
+                    plotObj.__getattr__(location).remove(layout)
 
 class dsiCompGraphGen:
 
+    def flush(self):
+        # Assuming 'plot' is your Bokeh figure
+        renderers = self.plot.renderers
+        for r in renderers:
+            self.plot.renderers.remove(r)
+        removeArrowsAndLabelsFromPlot(self.plot)
+        removeArrowsAndLabelsFromPlot(self.plot)# only second call removes the three arrows generated last
+        self.widget = row(self.plot)
+        self.unitLabels = {}
+        self.arrows = {}
+        self.scalFactorLables = {}
+
     def __init__(self,treeA=None,treeB=None):
         self.treeA=treeA
         self.treeB=treeB
-        self.plot= figure(width=500, height=500, x_range=(-1, 1), y_range=(0, 1), tools="save")
+        self.plot = figure(width=500, height=500, x_range=(-1, 1), y_range=(0, 1), tools="save")
         self.plot.xaxis.visible = False
         self.plot.yaxis.visible = False
         self.plot.grid.visible = False
-        self.unitLabels={}
-        self.arrows={}
-        self.scalFactorLables={}
+        self.flush()
         self.widget=row(self.plot)
 
-
     def reDraw(self,treeA=None,treeB=None):
         self.treeA=treeA
         self.treeB=treeB
@@ -119,6 +137,7 @@ class dsiCompGraphGen:
             x1 = quant1['coords'][0]
             y1 = quant1['coords'][1]
             name1 = quant1['name']
+
             x2 = quant2['coords'][0]
             y2 = quant2['coords'][1]
             name2 = quant2['name']
@@ -129,7 +148,10 @@ class dsiCompGraphGen:
                 self.plot.add_layout(self.arrows[name1+'_'+name2])
             scale12, baseUnit = quant1['baseUnit'].isScalablyEqualTo(quant2['baseUnit'])
             if not name1 + '_' + name2 in self.scalFactorLables:
-                self.scalFactorLables[name1 + '_' + name2]=Label(x=np.abs(x1-x2)/2+np.min([x1,x2]), y=np.abs(y1-y2)/2+np.min([y1,y2]), text="{:.4g}".format(scale12), text_font_size="12px", text_baseline=unitToDraw['text_baseLine'], text_align=unitToDraw['text_align'],text_color=colors[colIDX])
+                angle_deg = -1*np.arctan2(y2 - y1, x2 - x1)
+                if abs(angle_deg)>np.pi/8:
+                    angle_deg+=-np.pi/2
+                self.scalFactorLables[name1 + '_' + name2]=Label(x=np.abs(x1-x2)/2+np.min([x1,x2]), y=np.abs(y1-y2)/2+np.min([y1,y2]), text="{:.4g}".format(scale12), text_font_size="12px", text_baseline=unitToDraw['text_baseLine'], text_align=unitToDraw['text_align'],text_color=colors[colIDX],angle=angle_deg)
                 self.plot.add_layout(self.scalFactorLables[name1 + '_' + name2])
             else:
                 self.scalFactorLables[name1 + '_' + name2].text="{:.4g}".format(scale12)
@@ -141,7 +163,10 @@ class dsiCompGraphGen:
 
             scale21, baseUnit = quant2['baseUnit'].isScalablyEqualTo(quant1['baseUnit'])
             if not name2 + '_' + name1 in self.scalFactorLables:
-                self.scalFactorLables[name2 + '_' + name1]=Label(x=np.abs(x1-x2)/2+np.min([x1,x2])+0.05, y=np.abs(y1-y2)/2+np.min([y1,y2])+0.05, text="{:.4g}".format(scale21), text_font_size="12px", text_baseline=unitToDraw['text_baseLine'], text_align=unitToDraw['text_align'],text_color=colors[colIDX+1])
+                angle_deg = -1*np.arctan2(y2 - y1, x2 - x1)
+                if abs(angle_deg)>np.pi/8:
+                    angle_deg+=-np.pi/2
+                self.scalFactorLables[name2 + '_' + name1]=Label(x=np.abs(x1-x2)/2+np.min([x1,x2])+0.05, y=np.abs(y1-y2)/2+np.min([y1,y2])+0.05, text="{:.4g}".format(scale21), text_font_size="12px", text_baseline=unitToDraw['text_baseLine'], text_align=unitToDraw['text_align'],text_color=colors[colIDX+1],angle=angle_deg)
                 self.plot.add_layout(self.scalFactorLables[name2 + '_' + name1])
             else:
                 self.scalFactorLables[name2+ '_' + name1].text="{:.4g}".format(scale21)
@@ -153,8 +178,8 @@ class page():
         curdoc().title = "DSI to Latex"
         curdoc().add_root(bokehCssPTB.getStyleDiv())
         curdoc().theme = bokehCssPTB.getTheme()
-        self.dsiInput1 = dsiparserInput()
-        self.dsiInput2 = dsiparserInput()
+        self.dsiInput1 = dsiparserInput(defaultInput="\\milli\\newton\\metre")
+        self.dsiInput2 = dsiparserInput(defaultInput="\\kilo\\joule")
         self.inputs=row([self.dsiInput1.widget, self.dsiInput2.widget])
         curdoc().add_root(self.inputs)
 
@@ -169,12 +194,21 @@ class page():
     def compare(self):
         self.dsiInput1.parseInput()
         self.dsiInput2.parseInput()
-        scalfactor,baseUnit=self.dsiInput1.dsiTree.isScalablyEqualTo(self.dsiInput2.dsiTree)
-        if not math.isnan(scalfactor):
-            self.compaReresult.text = "The two units are equal up to a scaling factor of "+str(scalfactor)+" and a base unit of "+str(baseUnit)
-            self.compaReresult.css_classes=["msg-positive"]
-        else:
+        try:
+            scalfactor,baseUnit=self.dsiInput1.dsiTree.isScalablyEqualTo(self.dsiInput2.dsiTree)
+            if not math.isnan(scalfactor):
+                self.compaReresult.text = "The two units are equal up to a scaling factor of "+str(scalfactor)+" and a base unit of "+str(baseUnit)
+                self.compaReresult.css_classes=["msg-positive"]
+            else:
+                self.compaReresult.text = "The two units are not equal"
+                self.compaReresult.css_classes = ["msg-negative"]
+            if self.dsiInput1.valideUnit and self.dsiInput2.valideUnit:
+                self.dsiCompGraphGen.reDraw(self.dsiInput1.dsiTree,self.dsiInput2.dsiTree)
+            else:
+                self.dsiCompGraphGen.flush()
+        except AttributeError as Ae:
+            warnings.warn("AttributeError: "+str(Ae))
             self.compaReresult.text = "The two units are not equal"
             self.compaReresult.css_classes = ["msg-negative"]
-        self.dsiCompGraphGen.reDraw(self.dsiInput1.dsiTree,self.dsiInput2.dsiTree)
+            self.dsiCompGraphGen.flush()
 thisPage = page()
-- 
GitLab