Skip to content

Commit 69f77e9

Browse files
authored
Merge pull request #5 from Axiomatic-AI/new_axtract
axtract.py completely changed
2 parents 905f382 + 1d4e9df commit 69f77e9

File tree

1 file changed

+180
-42
lines changed

1 file changed

+180
-42
lines changed

src/axiomatic/axtract.py

Lines changed: 180 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
import ipywidgets as widgets # type: ignore
2-
from IPython.display import display # type: ignore
2+
from IPython.display import display, Math, HTML # type: ignore
33
from dataclasses import dataclass, field
4+
import hypernetx as hnx # type: ignore
5+
import matplotlib.pyplot as plt
6+
import re
47

58
OPTION_LIST = {
69
"Select a template": [],
@@ -18,7 +21,7 @@
1821
"Pixel size (multispectral)",
1922
"Swath width",
2023
],
21-
"PAYLOAD": [
24+
"PAYLOAD": [
2225
"Resolution (panchromatic)",
2326
"Ground sampling distance (panchromatic)",
2427
"Resolution (multispectral)",
@@ -117,15 +120,14 @@ def requirements_from_table(results, variable_dict):
117120
name = key
118121
numerical_value = value["Value"]
119122
unit = value["Units"]
120-
tolerance = value["Tolerance"]
121123

122124
requirements.append(
123125
Requirement(
124126
requirement_name=name,
125127
latex_symbol=latex_symbol,
126128
value=numerical_value,
127129
units=unit,
128-
tolerance=tolerance,
130+
tolerance=0.0,
129131
)
130132
)
131133

@@ -178,9 +180,15 @@ def display_table(change):
178180

179181
if selected_option in preset_options_dict:
180182
rows = preset_options_dict[selected_option]
181-
max_name_length = max(len(name) for name in rows)
182-
# Update the name_label_width based on the longest row name
183-
name_label_width[0] = f"{max_name_length + 2}ch"
183+
184+
if selected_option != "Select a template":
185+
max_name_length = max(len(name) for name in rows)
186+
# Update the name_label_width based on the longest row name
187+
name_label_width[0] = f"{max_name_length + 2}ch"
188+
else:
189+
max_name_length = 40
190+
# Update the name_label_width based on the longest row name
191+
name_label_width[0] = f"{max_name_length + 2}ch"
184192

185193
# Add Headers
186194
header_labels = [
@@ -194,16 +202,6 @@ def display_table(change):
194202
layout=widgets.Layout(width="150px"),
195203
style={'font_weight': 'bold'}
196204
),
197-
widgets.Label(
198-
value="Tolerance",
199-
layout=widgets.Layout(width="150px"),
200-
style={'font_weight': 'bold'}
201-
),
202-
widgets.Label(
203-
value="Accuracy",
204-
layout=widgets.Layout(width="150px"),
205-
style={'font_weight': 'bold'}
206-
),
207205
widgets.Label(
208206
value="Units",
209207
layout=widgets.Layout(width="150px"),
@@ -216,7 +214,6 @@ def display_table(change):
216214
header.layout = widgets.Layout(
217215
border='1px solid black',
218216
padding='5px',
219-
background_color='#f0f0f0'
220217
)
221218

222219
# Add the header to the rows_output VBox
@@ -244,28 +241,19 @@ def display_table(change):
244241

245242
# Create input widgets
246243
value_text = widgets.FloatText(
247-
placeholder="Value",
248244
value=default_value,
249245
layout=widgets.Layout(width="150px"),
250246
)
251-
tolerance_text = widgets.FloatText(
252-
placeholder="Tolerance", layout=widgets.Layout(width="150px")
253-
)
254-
accuracy_text = widgets.FloatText(
255-
placeholder="Accuracy", layout=widgets.Layout(width="150px")
256-
)
257247
units_text = widgets.Text(
258-
placeholder="Units", layout=widgets.Layout(width="150px"),
259-
value = default_unit
248+
layout=widgets.Layout(width="150px"),
249+
value=default_unit
260250
)
261251

262252
# Combine widgets into a horizontal box
263253
row = widgets.HBox(
264254
[
265255
name_label,
266256
value_text,
267-
tolerance_text,
268-
accuracy_text,
269257
units_text,
270258
]
271259
)
@@ -291,16 +279,12 @@ def submit_values(_):
291279
if key.startswith("req_"):
292280
updated_values[variable] = {
293281
"Value": widget.children[1].value,
294-
"Tolerance": widget.children[2].value,
295-
"Accuracy": widget.children[3].value,
296-
"Units": widget.children[4].value,
282+
"Units": widget.children[2].value,
297283
}
298284
else:
299285
updated_values[key] = {
300286
"Value": widget.children[1].value,
301-
"Tolerance": widget.children[2].value,
302-
"Accuracy": widget.children[3].value,
303-
"Units": widget.children[4].value,
287+
"Units": widget.children[2].value,
304288
}
305289

306290
result["values"] = updated_values
@@ -327,18 +311,13 @@ def add_req(_):
327311
placeholder="Value",
328312
layout=widgets.Layout(width="150px"),
329313
)
330-
tolerance_text = widgets.FloatText(
331-
placeholder="Tolerance", layout=widgets.Layout(width="150px")
332-
)
333-
accuracy_text = widgets.FloatText(
334-
placeholder="Accuracy", layout=widgets.Layout(width="150px")
335-
)
314+
336315
units_text = widgets.Text(
337316
placeholder="Units", layout=widgets.Layout(width="150px")
338317
)
339318

340319
new_row = widgets.HBox(
341-
[variable_dropdown, value_text, tolerance_text, accuracy_text, units_text]
320+
[variable_dropdown, value_text, units_text]
342321
)
343322

344323
rows_output.children += (new_row,)
@@ -354,3 +333,162 @@ def add_req(_):
354333
display(buttons_box)
355334

356335
return result
336+
337+
338+
def display_formatted_answers(equations_dict):
339+
"""
340+
Display LaTeX formatted equations and numerical results from a nested
341+
dictionary structure in Jupyter Notebook.
342+
343+
Parameters:
344+
equations_dict (dict): The dictionary containing the equations.
345+
"""
346+
results = equations_dict.get('results', {})
347+
print("We identified the following equations that are relevant to your requirements:")
348+
349+
for key, value in results.items():
350+
latex_equation = value.get('latex_equation')
351+
lhs = value.get('lhs')
352+
rhs = value.get('rhs')
353+
match = value.get('match')
354+
if latex_equation:
355+
display(Math(latex_equation))
356+
print(f"For provided values:\nleft hand side = {lhs}\nright hand side = {rhs}")
357+
if match:
358+
print("Provided requirements fulfill this mathematical relation")
359+
else:
360+
print(f"No LaTeX equation found for {key}")
361+
362+
363+
def display_results(equations_dict):
364+
365+
results = equations_dict.get('results', {})
366+
not_match_counter = 0
367+
368+
for key, value in results.items():
369+
match = value.get('match')
370+
latex_equation = value.get('latex_equation')
371+
lhs = value.get('lhs')
372+
rhs = value.get('rhs')
373+
if not match:
374+
not_match_counter += 1
375+
display(HTML(
376+
'<p style="color:red; '
377+
'font-weight:bold; '
378+
'font-family:\'Times New Roman\'; '
379+
'font-size:16px;">'
380+
'Provided requirements DO NOT fulfill the following mathematical relation:'
381+
'</p>'
382+
))
383+
display(Math(latex_equation))
384+
print(f"For provided values:\nleft hand side = {lhs}\nright hand side = {rhs}")
385+
if not_match_counter == 0:
386+
display(HTML(
387+
'<p style="color:green; '
388+
'font-weight:bold; '
389+
'font-family:\'Times New Roman\'; '
390+
'font-size:16px;">'
391+
'Requirements you provided do not cause any conflicts'
392+
'</p>'
393+
))
394+
395+
396+
def _get_latex_string_format(input_string):
397+
"""
398+
Properly formats LaTeX strings for matplotlib when text.usetex is False.
399+
No escaping needed since mathtext handles backslashes properly.
400+
"""
401+
return f"${input_string}$" # No backslash escaping required
402+
403+
404+
def _get_requirements_set(requirements):
405+
variable_set = set()
406+
for req in requirements:
407+
variable_set.add(req['latex_symbol'])
408+
409+
return variable_set
410+
411+
412+
def _find_vars_in_eq(equation, variable_set):
413+
patterns = [re.escape(var) for var in variable_set]
414+
combined_pattern = r'|'.join(patterns)
415+
matches = re.findall(combined_pattern, equation)
416+
return {fr"${match}$" for match in matches}
417+
418+
419+
def _add_used_vars_to_results(api_results, api_requirements):
420+
requirements = _get_requirements_set(api_requirements)
421+
422+
for key, value in api_results['results'].items():
423+
latex_equation = value.get('latex_equation')
424+
# print(latex_equation)
425+
if latex_equation:
426+
used_vars = _find_vars_in_eq(latex_equation, requirements)
427+
api_results['results'][key]['used_vars'] = used_vars
428+
429+
return api_results
430+
431+
432+
def get_eq_hypergraph(api_results, api_requirements):
433+
# Disable external LaTeX rendering, using matplotlib's mathtext instead
434+
plt.rcParams['text.usetex'] = False
435+
plt.rcParams['mathtext.fontset'] = 'stix'
436+
plt.rcParams['font.family'] = 'serif'
437+
438+
api_results = _add_used_vars_to_results(api_results, api_requirements)
439+
440+
# Prepare the data for HyperNetX visualization
441+
hyperedges = {}
442+
for eq, details in api_results["results"].items():
443+
hyperedges[_get_latex_string_format(
444+
details["latex_equation"])] = details["used_vars"]
445+
446+
# Create the hypergraph using HyperNetX
447+
H = hnx.Hypergraph(hyperedges)
448+
449+
# Plot the hypergraph with enhanced clarity
450+
plt.figure(figsize=(16, 12))
451+
452+
# Draw the hypergraph with node and edge labels
453+
hnx.draw(
454+
H,
455+
with_edge_labels=True,
456+
edge_labels_on_edge=False,
457+
node_labels_kwargs={'fontsize': 14},
458+
edge_labels_kwargs={'fontsize': 14},
459+
layout_kwargs={'seed': 42, 'scale': 2.5}
460+
)
461+
462+
node_labels = list(H.nodes)
463+
symbol_explanations = _get_node_names_for_node_lables(node_labels, api_requirements)
464+
465+
# Adding the symbol explanations as a legend
466+
explanation_text = "\n".join([f"${symbol}$: {desc}" for symbol, desc in symbol_explanations])
467+
plt.annotate(
468+
explanation_text,
469+
xy=(1.05, 0.5),
470+
xycoords='axes fraction',
471+
fontsize=14,
472+
verticalalignment='center'
473+
)
474+
475+
plt.title(r"Enhanced Hypergraph of Equations and Variables", fontsize=20)
476+
plt.show()
477+
478+
479+
def _get_node_names_for_node_lables(node_labels, api_requirements):
480+
481+
# Create the output list
482+
node_names = []
483+
484+
# Iterate through each symbol in S
485+
for symbol in node_labels:
486+
# Search for the matching requirement
487+
symbol = symbol.replace("$", "")
488+
for req in api_requirements:
489+
if req['latex_symbol'] == symbol:
490+
# Add the matching tuple to SS
491+
node_names.append((req["latex_symbol"], req["requirement_name"]))
492+
break # Stop searching once a match is found
493+
494+
return node_names

0 commit comments

Comments
 (0)