2121from .string_distance import StringDistance
2222
2323
24- class CharacterInsDelInterface :
24+ def default_insertion_cost (char ):
25+ return 1.0
2526
26- def deletion_cost (self , c ):
27- raise NotImplementedError ()
2827
29- def insertion_cost ( self , c ):
30- raise NotImplementedError ()
28+ def default_deletion_cost ( char ):
29+ return 1.0
3130
3231
33- class CharacterSubstitutionInterface :
34-
35- def cost (self , c0 , c1 ):
36- raise NotImplementedError ()
32+ def default_substitution_cost (char_a , char_b ):
33+ return 1.0
3734
3835
3936class WeightedLevenshtein (StringDistance ):
4037
41- def __init__ (self , character_substitution , character_ins_del = None ):
42- self .character_ins_del = character_ins_del
43- if character_substitution is None :
44- raise TypeError ("Argument character_substitution is NoneType." )
45- self .character_substitution = character_substitution
38+ def __init__ (self ,
39+ substitution_cost_fn = default_substitution_cost ,
40+ insertion_cost_fn = default_insertion_cost ,
41+ deletion_cost_fn = default_deletion_cost ,
42+ ):
43+ self .substitution_cost_fn = substitution_cost_fn
44+ self .insertion_cost_fn = insertion_cost_fn
45+ self .deletion_cost_fn = deletion_cost_fn
4646
4747 def distance (self , s0 , s1 ):
4848 if s0 is None :
@@ -60,30 +60,20 @@ def distance(self, s0, s1):
6060
6161 v0 [0 ] = 0
6262 for i in range (1 , len (v0 )):
63- v0 [i ] = v0 [i - 1 ] + self ._insertion_cost (s1 [i - 1 ])
63+ v0 [i ] = v0 [i - 1 ] + self .insertion_cost_fn (s1 [i - 1 ])
6464
6565 for i in range (len (s0 )):
6666 s1i = s0 [i ]
67- deletion_cost = self ._deletion_cost (s1i )
67+ deletion_cost = self .deletion_cost_fn (s1i )
6868 v1 [0 ] = v0 [0 ] + deletion_cost
6969
7070 for j in range (len (s1 )):
7171 s2j = s1 [j ]
7272 cost = 0
7373 if s1i != s2j :
74- cost = self .character_substitution . cost (s1i , s2j )
75- insertion_cost = self ._insertion_cost (s2j )
74+ cost = self .substitution_cost_fn (s1i , s2j )
75+ insertion_cost = self .insertion_cost_fn (s2j )
7676 v1 [j + 1 ] = min (v1 [j ] + insertion_cost , v0 [j + 1 ] + deletion_cost , v0 [j ] + cost )
7777 v0 , v1 = v1 , v0
7878
7979 return v0 [len (s1 )]
80-
81- def _insertion_cost (self , c ):
82- if self .character_ins_del is None :
83- return 1.0
84- return self .character_ins_del .insertion_cost (c )
85-
86- def _deletion_cost (self , c ):
87- if self .character_ins_del is None :
88- return 1.0
89- return self .character_ins_del .deletion_cost (c )
0 commit comments