11import ipywidgets as widgets # type: ignore
2- from IPython .display import display # type: ignore
2+ from IPython .display import display , Math , HTML # type: ignore
33from dataclasses import dataclass , field
4+ import hypernetx as hnx # type: ignore
5+ import matplotlib .pyplot as plt
6+ import re
47
58OPTION_LIST = {
69 "Select a template" : [],
1821 "Pixel size (multispectral)" ,
1922 "Swath width" ,
2023 ],
21- "PAYLOAD" : [
24+ "PAYLOAD" : [
2225 "Resolution (panchromatic)" ,
2326 "Ground sampling distance (panchromatic)" ,
2427 "Resolution (multispectral)" ,
@@ -117,15 +120,14 @@ def requirements_from_table(results, variable_dict):
117120 name = key
118121 numerical_value = value ["Value" ]
119122 unit = value ["Units" ]
120- tolerance = value ["Tolerance" ]
121123
122124 requirements .append (
123125 Requirement (
124126 requirement_name = name ,
125127 latex_symbol = latex_symbol ,
126128 value = numerical_value ,
127129 units = unit ,
128- tolerance = tolerance ,
130+ tolerance = 0.0 ,
129131 )
130132 )
131133
@@ -178,9 +180,15 @@ def display_table(change):
178180
179181 if selected_option in preset_options_dict :
180182 rows = preset_options_dict [selected_option ]
181- max_name_length = max (len (name ) for name in rows )
182- # Update the name_label_width based on the longest row name
183- name_label_width [0 ] = f"{ max_name_length + 2 } ch"
183+
184+ if selected_option != "Select a template" :
185+ max_name_length = max (len (name ) for name in rows )
186+ # Update the name_label_width based on the longest row name
187+ name_label_width [0 ] = f"{ max_name_length + 2 } ch"
188+ else :
189+ max_name_length = 40
190+ # Update the name_label_width based on the longest row name
191+ name_label_width [0 ] = f"{ max_name_length + 2 } ch"
184192
185193 # Add Headers
186194 header_labels = [
@@ -194,16 +202,6 @@ def display_table(change):
194202 layout = widgets .Layout (width = "150px" ),
195203 style = {'font_weight' : 'bold' }
196204 ),
197- widgets .Label (
198- value = "Tolerance" ,
199- layout = widgets .Layout (width = "150px" ),
200- style = {'font_weight' : 'bold' }
201- ),
202- widgets .Label (
203- value = "Accuracy" ,
204- layout = widgets .Layout (width = "150px" ),
205- style = {'font_weight' : 'bold' }
206- ),
207205 widgets .Label (
208206 value = "Units" ,
209207 layout = widgets .Layout (width = "150px" ),
@@ -216,7 +214,6 @@ def display_table(change):
216214 header .layout = widgets .Layout (
217215 border = '1px solid black' ,
218216 padding = '5px' ,
219- background_color = '#f0f0f0'
220217 )
221218
222219 # Add the header to the rows_output VBox
@@ -244,28 +241,19 @@ def display_table(change):
244241
245242 # Create input widgets
246243 value_text = widgets .FloatText (
247- placeholder = "Value" ,
248244 value = default_value ,
249245 layout = widgets .Layout (width = "150px" ),
250246 )
251- tolerance_text = widgets .FloatText (
252- placeholder = "Tolerance" , layout = widgets .Layout (width = "150px" )
253- )
254- accuracy_text = widgets .FloatText (
255- placeholder = "Accuracy" , layout = widgets .Layout (width = "150px" )
256- )
257247 units_text = widgets .Text (
258- placeholder = "Units" , layout = widgets .Layout (width = "150px" ),
259- value = default_unit
248+ layout = widgets .Layout (width = "150px" ),
249+ value = default_unit
260250 )
261251
262252 # Combine widgets into a horizontal box
263253 row = widgets .HBox (
264254 [
265255 name_label ,
266256 value_text ,
267- tolerance_text ,
268- accuracy_text ,
269257 units_text ,
270258 ]
271259 )
@@ -291,16 +279,12 @@ def submit_values(_):
291279 if key .startswith ("req_" ):
292280 updated_values [variable ] = {
293281 "Value" : widget .children [1 ].value ,
294- "Tolerance" : widget .children [2 ].value ,
295- "Accuracy" : widget .children [3 ].value ,
296- "Units" : widget .children [4 ].value ,
282+ "Units" : widget .children [2 ].value ,
297283 }
298284 else :
299285 updated_values [key ] = {
300286 "Value" : widget .children [1 ].value ,
301- "Tolerance" : widget .children [2 ].value ,
302- "Accuracy" : widget .children [3 ].value ,
303- "Units" : widget .children [4 ].value ,
287+ "Units" : widget .children [2 ].value ,
304288 }
305289
306290 result ["values" ] = updated_values
@@ -327,18 +311,13 @@ def add_req(_):
327311 placeholder = "Value" ,
328312 layout = widgets .Layout (width = "150px" ),
329313 )
330- tolerance_text = widgets .FloatText (
331- placeholder = "Tolerance" , layout = widgets .Layout (width = "150px" )
332- )
333- accuracy_text = widgets .FloatText (
334- placeholder = "Accuracy" , layout = widgets .Layout (width = "150px" )
335- )
314+
336315 units_text = widgets .Text (
337316 placeholder = "Units" , layout = widgets .Layout (width = "150px" )
338317 )
339318
340319 new_row = widgets .HBox (
341- [variable_dropdown , value_text , tolerance_text , accuracy_text , units_text ]
320+ [variable_dropdown , value_text , units_text ]
342321 )
343322
344323 rows_output .children += (new_row ,)
@@ -354,3 +333,162 @@ def add_req(_):
354333 display (buttons_box )
355334
356335 return result
336+
337+
338+ def display_formatted_answers (equations_dict ):
339+ """
340+ Display LaTeX formatted equations and numerical results from a nested
341+ dictionary structure in Jupyter Notebook.
342+
343+ Parameters:
344+ equations_dict (dict): The dictionary containing the equations.
345+ """
346+ results = equations_dict .get ('results' , {})
347+ print ("We identified the following equations that are relevant to your requirements:" )
348+
349+ for key , value in results .items ():
350+ latex_equation = value .get ('latex_equation' )
351+ lhs = value .get ('lhs' )
352+ rhs = value .get ('rhs' )
353+ match = value .get ('match' )
354+ if latex_equation :
355+ display (Math (latex_equation ))
356+ print (f"For provided values:\n left hand side = { lhs } \n right hand side = { rhs } " )
357+ if match :
358+ print ("Provided requirements fulfill this mathematical relation" )
359+ else :
360+ print (f"No LaTeX equation found for { key } " )
361+
362+
363+ def display_results (equations_dict ):
364+
365+ results = equations_dict .get ('results' , {})
366+ not_match_counter = 0
367+
368+ for key , value in results .items ():
369+ match = value .get ('match' )
370+ latex_equation = value .get ('latex_equation' )
371+ lhs = value .get ('lhs' )
372+ rhs = value .get ('rhs' )
373+ if not match :
374+ not_match_counter += 1
375+ display (HTML (
376+ '<p style="color:red; '
377+ 'font-weight:bold; '
378+ 'font-family:\' Times New Roman\' ; '
379+ 'font-size:16px;">'
380+ 'Provided requirements DO NOT fulfill the following mathematical relation:'
381+ '</p>'
382+ ))
383+ display (Math (latex_equation ))
384+ print (f"For provided values:\n left hand side = { lhs } \n right hand side = { rhs } " )
385+ if not_match_counter == 0 :
386+ display (HTML (
387+ '<p style="color:green; '
388+ 'font-weight:bold; '
389+ 'font-family:\' Times New Roman\' ; '
390+ 'font-size:16px;">'
391+ 'Requirements you provided do not cause any conflicts'
392+ '</p>'
393+ ))
394+
395+
396+ def _get_latex_string_format (input_string ):
397+ """
398+ Properly formats LaTeX strings for matplotlib when text.usetex is False.
399+ No escaping needed since mathtext handles backslashes properly.
400+ """
401+ return f"${ input_string } $" # No backslash escaping required
402+
403+
404+ def _get_requirements_set (requirements ):
405+ variable_set = set ()
406+ for req in requirements :
407+ variable_set .add (req ['latex_symbol' ])
408+
409+ return variable_set
410+
411+
412+ def _find_vars_in_eq (equation , variable_set ):
413+ patterns = [re .escape (var ) for var in variable_set ]
414+ combined_pattern = r'|' .join (patterns )
415+ matches = re .findall (combined_pattern , equation )
416+ return {fr"${ match } $" for match in matches }
417+
418+
419+ def _add_used_vars_to_results (api_results , api_requirements ):
420+ requirements = _get_requirements_set (api_requirements )
421+
422+ for key , value in api_results ['results' ].items ():
423+ latex_equation = value .get ('latex_equation' )
424+ # print(latex_equation)
425+ if latex_equation :
426+ used_vars = _find_vars_in_eq (latex_equation , requirements )
427+ api_results ['results' ][key ]['used_vars' ] = used_vars
428+
429+ return api_results
430+
431+
432+ def get_eq_hypergraph (api_results , api_requirements ):
433+ # Disable external LaTeX rendering, using matplotlib's mathtext instead
434+ plt .rcParams ['text.usetex' ] = False
435+ plt .rcParams ['mathtext.fontset' ] = 'stix'
436+ plt .rcParams ['font.family' ] = 'serif'
437+
438+ api_results = _add_used_vars_to_results (api_results , api_requirements )
439+
440+ # Prepare the data for HyperNetX visualization
441+ hyperedges = {}
442+ for eq , details in api_results ["results" ].items ():
443+ hyperedges [_get_latex_string_format (
444+ details ["latex_equation" ])] = details ["used_vars" ]
445+
446+ # Create the hypergraph using HyperNetX
447+ H = hnx .Hypergraph (hyperedges )
448+
449+ # Plot the hypergraph with enhanced clarity
450+ plt .figure (figsize = (16 , 12 ))
451+
452+ # Draw the hypergraph with node and edge labels
453+ hnx .draw (
454+ H ,
455+ with_edge_labels = True ,
456+ edge_labels_on_edge = False ,
457+ node_labels_kwargs = {'fontsize' : 14 },
458+ edge_labels_kwargs = {'fontsize' : 14 },
459+ layout_kwargs = {'seed' : 42 , 'scale' : 2.5 }
460+ )
461+
462+ node_labels = list (H .nodes )
463+ symbol_explanations = _get_node_names_for_node_lables (node_labels , api_requirements )
464+
465+ # Adding the symbol explanations as a legend
466+ explanation_text = "\n " .join ([f"${ symbol } $: { desc } " for symbol , desc in symbol_explanations ])
467+ plt .annotate (
468+ explanation_text ,
469+ xy = (1.05 , 0.5 ),
470+ xycoords = 'axes fraction' ,
471+ fontsize = 14 ,
472+ verticalalignment = 'center'
473+ )
474+
475+ plt .title (r"Enhanced Hypergraph of Equations and Variables" , fontsize = 20 )
476+ plt .show ()
477+
478+
479+ def _get_node_names_for_node_lables (node_labels , api_requirements ):
480+
481+ # Create the output list
482+ node_names = []
483+
484+ # Iterate through each symbol in S
485+ for symbol in node_labels :
486+ # Search for the matching requirement
487+ symbol = symbol .replace ("$" , "" )
488+ for req in api_requirements :
489+ if req ['latex_symbol' ] == symbol :
490+ # Add the matching tuple to SS
491+ node_names .append ((req ["latex_symbol" ], req ["requirement_name" ]))
492+ break # Stop searching once a match is found
493+
494+ return node_names
0 commit comments