Skip to content

Commit 8a42b52

Browse files
authored
Merge pull request #23 from Axiomatic-AI/kevin_small_changes
Kevin small changes
2 parents aaac16d + cd90de2 commit 8a42b52

File tree

2 files changed

+480
-135
lines changed

2 files changed

+480
-135
lines changed

src/axiomatic/axtract.py

Lines changed: 56 additions & 135 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
import re
88
from dataclasses import dataclass, asdict
99

10-
1110
@dataclass
1211
class RequirementUserInput:
1312
requirement_name: str
@@ -411,139 +410,6 @@ def save_requirements(_):
411410
return result
412411

413412

414-
def display_results(equations_dict):
415-
"""Display equation validation results in a clear, organized format."""
416-
results = equations_dict.get("results", {})
417-
418-
# Helper function to convert Eq(LHS,RHS) to LHS=RHS format
419-
def format_equation(latex_eq):
420-
# Remove 'Eq(' from start and ')' from end
421-
inner = latex_eq[3:-1]
422-
# Split by comma and join with equals sign
423-
lhs, rhs = inner.split(',', 1)
424-
return f"{lhs} = {rhs}"
425-
426-
# Split results into matching and non-matching equations
427-
matching = []
428-
non_matching = []
429-
430-
for key, value in results.items():
431-
equation_data = {
432-
'latex': format_equation(value.get('latex_equation')),
433-
'lhs': value.get('lhs'),
434-
'rhs': value.get('rhs'),
435-
'diff': abs(value.get('lhs', 0) - value.get('rhs', 0)),
436-
'percent_diff': abs(value.get('lhs', 0) - value.get('rhs', 0)) / max(abs(value.get('rhs', 0)), 1e-10) * 100
437-
}
438-
if value.get('match'):
439-
matching.append(equation_data)
440-
else:
441-
non_matching.append(equation_data)
442-
443-
# Display summary header
444-
total = len(results)
445-
display(HTML(
446-
f'<h3 style="font-family:Arial">Equation Validation Summary</h3>'
447-
f'<p style="font-family:Arial">Total equations checked: {total}<br>'
448-
f'✅ Matching equations: {len(matching)}<br>'
449-
f'❌ Non-matching equations: {len(non_matching)}</p>'
450-
))
451-
452-
# Display non-matching equations first (if any)
453-
if non_matching:
454-
display(HTML(
455-
'<div style="background-color:#fff0f0; padding:10px; border-radius:5px; margin:10px 0;">'
456-
'<h4 style="color:#cc0000; font-family:Arial">⚠️ Equations Not Satisfied:</h4>'
457-
))
458-
459-
for eq in non_matching:
460-
display(Math(eq['latex']))
461-
display(HTML(
462-
f'<div style="font-family:monospace; margin-left:20px; margin-bottom:15px">'
463-
f'Left side = {eq["lhs"]:.6g}<br>'
464-
f'Right side = {eq["rhs"]:.6g}<br>'
465-
f'Difference = {eq["diff"]:.6g}<br>'
466-
f'Percent difference = {eq["percent_diff"]:.2f}%'
467-
'</div>'
468-
))
469-
470-
display(HTML('</div>'))
471-
472-
# Display matching equations (if any)
473-
if matching:
474-
display(HTML(
475-
'<div style="background-color:#f0fff0; padding:10px; border-radius:5px; margin:10px 0;">'
476-
'<h4 style="color:#006600; font-family:Arial">✅ Satisfied Equations:</h4>'
477-
))
478-
479-
for eq in matching:
480-
display(Math(eq['latex']))
481-
display(HTML(
482-
f'<div style="font-family:monospace; margin-left:20px; margin-bottom:15px">'
483-
f'Value = {eq["lhs"]:.6g}'
484-
'</div>'
485-
))
486-
487-
display(HTML('</div>'))
488-
489-
def get_eq_hypergraph(api_results, requirements, with_printing=True):
490-
491-
list_api_requirements = [asdict(req) for req in requirements]
492-
493-
# Disable external LaTeX rendering, using matplotlib's mathtext instead
494-
plt.rcParams["text.usetex"] = False
495-
plt.rcParams["mathtext.fontset"] = "stix"
496-
plt.rcParams["font.family"] = "serif"
497-
498-
api_results = _add_used_vars_to_results(api_results, list_api_requirements)
499-
500-
# Prepare the data for HyperNetX visualization
501-
hyperedges = {}
502-
for eq, details in api_results["results"].items():
503-
hyperedges[
504-
_get_latex_string_format(details["latex_equation"])] = details[
505-
"used_vars"
506-
]
507-
508-
# Create the hypergraph using HyperNetX
509-
H = hnx.Hypergraph(hyperedges)
510-
511-
# Plot the hypergraph with enhanced clarity
512-
plt.figure(figsize=(16, 12))
513-
514-
# Draw the hypergraph with node and edge labels
515-
hnx.draw(
516-
H,
517-
with_edge_labels=True,
518-
edge_labels_on_edge=False,
519-
node_labels_kwargs={"fontsize": 14},
520-
edge_labels_kwargs={"fontsize": 14},
521-
layout_kwargs={"seed": 42, "scale": 2.5},
522-
)
523-
524-
node_labels = list(H.nodes)
525-
symbol_explanations = _get_node_names_for_node_lables(
526-
node_labels,
527-
list_api_requirements)
528-
529-
# Adding the symbol explanations as a legend
530-
explanation_text = "\n".join(
531-
[f"${symbol}$: {desc}" for symbol, desc in symbol_explanations]
532-
)
533-
plt.annotate(
534-
explanation_text,
535-
xy=(1.05, 0.5),
536-
xycoords="axes fraction",
537-
fontsize=14,
538-
verticalalignment="center",
539-
)
540-
plt.title(r"Enhanced Hypergraph of Equations and Variables", fontsize=20)
541-
if with_printing:
542-
plt.show()
543-
return H
544-
else:
545-
return H
546-
547413

548414
def _get_node_names_for_node_lables(node_labels, api_requirements):
549415

@@ -741,4 +607,59 @@ def format_equation(latex_eq):
741607
plt.title(r"Enhanced Hypergraph of Equations and Variables", fontsize=20)
742608
plt.show()
743609

744-
return None
610+
return None
611+
612+
613+
def get_numerical_values(ax_client, path, constants_of_interest):
614+
with open(path, "rb") as f:
615+
file = f.read()
616+
617+
constants = ax_client.document.constants(file=file, constants=constants_of_interest).constants
618+
619+
# Create a dictionary to store processed values
620+
processed_values = {}
621+
622+
# Process each constant name from the constants dictionary
623+
for constant_name in constants:
624+
value_str = constants[constant_name] # Get the value directly from the dictionary
625+
626+
if value_str is None:
627+
# Handle None values
628+
processed_values[constant_name] = {
629+
"Value": 0.0,
630+
"Units": "unknown"
631+
}
632+
elif 'F/' in value_str:
633+
# Handle F-number values
634+
f_number = float(value_str.split('/')[-1])
635+
processed_values[constant_name] = {
636+
"Value": f_number,
637+
"Units": "dimensionless"
638+
}
639+
else:
640+
# Handle normal values with units
641+
# Split on the last space to separate value and unit
642+
parts = value_str.rsplit(' ', 1)
643+
if len(parts) == 2:
644+
value, unit = parts
645+
processed_values[constant_name] = {
646+
"Value": float(value),
647+
"Units": unit
648+
}
649+
else:
650+
# If no unit is found
651+
processed_values[constant_name] = {
652+
"Value": float(parts[0]),
653+
"Units": "unknown"
654+
}
655+
656+
# Save as custom preset
657+
filename = os.path.basename(path)
658+
with open("./custom_presets.json", "r+") as f:
659+
presets = json.load(f)
660+
presets[filename] = processed_values
661+
f.seek(0)
662+
json.dump(presets, f, indent=2)
663+
f.truncate()
664+
665+
return processed_values

0 commit comments

Comments
 (0)