设定 {n} 匹配长度为n的任意子串;{m, n} 匹配长度为m到n的任意子串(0 ≤ m ≤ n)。
对于带有通配符{n}的模式串:
{1}ATC{2}TC{1}ATC
将其按照通配符分割成三个模式子串:ATC、TC、ATC。
然后使用这三个模式子串构造AC自动机。
每个模式子串的末尾字符记录该字符(及其fail指针指向的字符)在原模式串中的位置。

其中,节点3既是第一个模式子串的结尾字符(在原模式串中的位置是4,从1开始计数),也是第三个模式子串的结尾字符(在原模式串中的位置是12);它的fail指针指向节点5(在原模式串中的位置是8)。
节点5是第二个模式串的结尾字符(在原模式串中的位置是8)。
在进行模式匹配的时候,在某个模式子串匹配成功之后,则根据其记录的位置信息,来判断起始点所在的位置,然后在这个位置上加1。最后就可以得到查询字符串在每个位置上的count值。
比如,当查询字符串是“ACGATCTCTCGATC”的时候,在匹配到第5个字符的时候(绿色,从0开始计数),自动机走到节点3,那么可以推出起点可能在查询字符串的第2个位置,将该位置的count值加1:
0 0 1 0 0 0 0 0 0 0 0 0 0 0
在匹配到第7个字符的时候,自动机走到节点5(在原模式串中的位置是8),那么可以推出起点可能在查询字符串的第0个位置(7 - 8 + 1 = 0),将该位置的count值加1:
1 0 1 0 0 0 0 0 0 0 0 0 0 0
在匹配到第9个字符(蓝色)的时候,自动机走到节点5,那么起点可能在查询字符串第2个位置(9 - 8 + 1 = 2),将该位置的count值加1:
1 0 2 0 0 0 0 0 0 0 0 0 0 0
在匹配到第13个字符(黄色)的时候,自动机走到节点3,再沿fail指针走到节点5,那么,可以推出起点可能在查询字符串的第2个、第6个、第10个位置,将这些位置上的count值加1:
1 0 3 0 0 0 1 0 0 0 1 0 0 0
最后,可以看到第2个位置(从0开始计数)的count值为3,和模式子串的数量相同,因此,得出通配符模式串匹配成功,在查询字符串上的起点是2。
# coding: utf8
import collections
import re
# Base class of every exception defined in this module
class FilterSensitiveWordsException(Exception):
    """Root of the sensitive-word filter exception hierarchy.

    Derives from ``Exception`` rather than ``BaseException`` so that
    ordinary ``except Exception`` handlers catch it, while handlers
    written against this class keep working unchanged.
    """
# Raised when a sensitive word is added after the automaton has been built
class ACAlreadyBuiltException(FilterSensitiveWordsException):
    """The Aho-Corasick automaton is already built and cannot be extended."""
# Sensitive words and filtered strings must be unicode text; anything else
# raises this exception
class InvalidEncodingException(FilterSensitiveWordsException):
    """A sensitive word or an input string was not a unicode string."""
# Raised when the automaton is used for filtering before it has been built
class ACNotYetBuiltException(FilterSensitiveWordsException):
    """Filtering was requested before the automaton was built."""
# Raised when a pattern string is malformed
class InvalidPatternException(FilterSensitiveWordsException):
    """The wildcard pattern string has an invalid format."""
# Trie树节点
class _Node(object):
def __init__(self):
self._subnodes = {}
self._is_end = False
self._fail = None
def set_default(self, character):
if character not in self._subnodes:
self._subnodes[character] = _Node()
return self._subnodes[character]
def get_subnode(self, character):
return self._subnodes.get(character, None)
def mark_as_end(self):
self._is_end = True
@property
def is_end(self):
return self._is_end
def iter_subnodes(self):
for character, subnode in self._subnodes.iteritems():
yield character, subnode
def set_fail(self, node):
self._fail = node
def get_fail(self):
return self._fail
# Aho-Corasick automaton implementation
class FilterSensitiveWords(object):
    """Aho-Corasick automaton that matches wildcard patterns.

    Every pattern is split into plain sub-words.  Each sub-word is stored
    in the trie together with the positions at which it may end inside its
    pattern.  While scanning, every sub-word hit votes for the start
    positions it is compatible with; a start position that collects
    ``total_count`` votes is reported as a match.
    """

    # Text type on both Python 2 (``unicode``) and Python 3 (``str``).
    _TEXT_TYPE = type(u"")

    def __init__(self):
        self._is_built = False
        self._root = _Node()
        # end node -> {pattern: {"offsets": [...], "total_count": int}}
        self._node_to_pattern = {}

    def add_sensitive_word(self, sensitive_word, pattern, offsets, total_count):
        """Insert one sub-word of *pattern* into the trie.

        :param sensitive_word: unicode sub-word to insert.
        :param pattern: the pattern the sub-word belongs to (used as a key).
        :param offsets: sequence of ``(low, high)`` pairs; each pair is the
            inclusive range of 1-based positions inside the pattern at which
            the last character of the sub-word may sit.
        :param total_count: number of votes a start position needs before
            the pattern is reported as matched.
        :raises ACAlreadyBuiltException: if :meth:`build` already ran.
        :raises InvalidEncodingException: if *sensitive_word* is not unicode.
        """
        if self._is_built:
            raise ACAlreadyBuiltException
        if not isinstance(sensitive_word, self._TEXT_TYPE):
            raise InvalidEncodingException
        node = self._root
        for character in sensitive_word:
            node = node.set_default(character)
        node.mark_as_end()
        self._node_to_pattern.setdefault(node, {})[pattern] = {
            "offsets": offsets,
            "total_count": total_count,
        }

    def build(self):
        """Compute the failure links; calling it again is a no-op."""
        if self._is_built:
            return
        self._is_built = True
        # The root has no failure link.
        self._root.set_fail(None)
        # Breadth-first walk; a deque gives O(1) pops from the left.
        queue = collections.deque([self._root])
        while queue:
            node = queue.popleft()
            for character, subnode in node.iter_subnodes():
                queue.append(subnode)
                # Children of the root always fail back to the root.
                if node is self._root:
                    subnode.set_fail(self._root)
                    continue
                # Follow the parent's failure chain until some node has a
                # child for the same character.
                candidate = node.get_fail()
                while candidate is not None:
                    target = candidate.get_subnode(character)
                    if target is not None:
                        subnode.set_fail(target)
                        break
                    candidate = candidate.get_fail()
                else:
                    # Chain exhausted: fall back to the root.
                    subnode.set_fail(self._root)

    def _count_hit(self, pattern, info, ind, matching_patterns, string_length):
        """Vote for every start position compatible with one sub-word hit.

        Returns the start index that reached ``info["total_count"]``, or
        ``None`` when no start position is complete yet.  On completion the
        vote table of *pattern* is dropped so that only non-overlapping
        matches are reported, and the remaining offsets are skipped (the
        table they would update no longer exists).

        NOTE(review): votes are tallied per start position, not per
        distinct sub-word, so with ranged gaps repeated hits of one
        sub-word can vote for the same start position more than once.
        """
        counts = matching_patterns.get(pattern)
        if counts is None:
            counts = matching_patterns[pattern] = [0] * string_length
        for offset in info["offsets"]:
            for pos in range(offset[0], offset[1] + 1):
                start = ind - pos + 1
                if start < 0 or start >= string_length:
                    continue
                counts[start] += 1
                if counts[start] >= info["total_count"]:
                    del matching_patterns[pattern]
                    return start
        return None

    def _get_output(self, p, ind, matching_patterns, string_length):
        """Collect the matches completed by the hit ending at index *ind*.

        Walks from node *p* along the failure chain and processes every end
        node on it.  Returns ``{start: [[pattern, ind], ...]}``.
        """
        outputs = {}
        while p is not None:
            if p.is_end:
                for pattern, info in self._node_to_pattern[p].items():
                    start = self._count_hit(pattern, info, ind,
                                            matching_patterns, string_length)
                    if start is not None:
                        outputs.setdefault(start, []).append([pattern, ind])
            p = p.get_fail()
        return outputs

    def _merge(self, d1, d2):
        """Merge the result dict *d2* into *d1* in place."""
        for key, matches in d2.items():
            if key in d1:
                d1[key].extend(matches)
            else:
                d1[key] = matches

    def filter(self, string):
        """Scan *string* and return every completed pattern match.

        :param string: unicode text to scan.
        :returns: dict mapping a start index to a list of
            ``[pattern, end_index]`` entries.
        :raises InvalidEncodingException: if *string* is not unicode.
        :raises ACNotYetBuiltException: if :meth:`build` has not run.
        """
        if not isinstance(string, self._TEXT_TYPE):
            raise InvalidEncodingException
        if not self._is_built:
            raise ACNotYetBuiltException
        matching_patterns = {}
        string_length = len(string)
        node = self._root
        outputs = {}
        for ind, character in enumerate(string):
            while node is not None:
                child = node.get_subnode(character)
                if child is None:
                    node = node.get_fail()
                    continue
                # Always walk the failure chain: a shorter sub-word may end
                # here even when the reached node itself ends no word.
                self._merge(outputs,
                            self._get_output(child, ind,
                                             matching_patterns, string_length))
                node = child
                break
            else:
                # Fell off the root: restart from it for the next character.
                node = self._root
        return outputs
class Pattern(object):
    """Parsed wildcard pattern such as ``A{1}B{0,3}C``.

    ``{n}`` stands for exactly *n* arbitrary characters and ``{m,n}`` for
    *m* to *n* arbitrary characters.  The pattern is split on these
    wildcards into plain sub-words, and for every sub-word the range of
    1-based positions at which its last character may sit is computed.
    """

    # "splitter" is the whole wildcard including surrounding whitespace;
    # "first" is n of {n}; "second"/"third" are m and n of {m,n}.
    _reg_exp = re.compile(
        r"(?P<splitter>\s*\{\s*(?P<first>\d+)\s*\}\s*|"
        r"\s*\{\s*(?P<second>\d+)\s*,\s*(?P<third>\d+)\s*\}\s*)", re.U)

    def __init__(self, pattern):
        self._pattern = pattern
        self._tokens = self._parse_pattern(self.__class__._reg_exp, pattern)
        self._result, self._count = self._get_word_position(self._tokens)

    def _parse_pattern(self, reg_exp, pattern):
        """Split *pattern* into a list of sub-words and ``(min, max)`` gaps.

        :raises InvalidPatternException: if a gap has ``min > max``.
        """
        tokens = []
        while pattern:
            pattern = pattern.strip()
            m = reg_exp.search(pattern)
            if m is None:
                break
            # Text before the wildcard is a sub-word; the rest is parsed
            # on the next iteration.
            first_part = pattern[:m.start("splitter")].strip()
            pattern = pattern[m.end("splitter"):]
            if m.group("first") is not None:
                first = int(m.group("first"))
                repeat = (first, first)
            else:
                repeat = (int(m.group("second")), int(m.group("third")))
            # Validate the gap bounds.
            if not (0 <= repeat[0] <= repeat[1]):
                raise InvalidPatternException
            if first_part:
                tokens.append(first_part)
            tokens.append(repeat)
        if pattern:
            tokens.append(pattern)
        return tokens

    def _get_word_position(self, tokens):
        """Compute where each sub-word may end inside the pattern.

        :returns: ``(result, count)`` where *result* maps each sub-word to
            a list of ``(low, high)`` inclusive 1-based end positions and
            *count* is the number of sub-word occurrences.
        :raises InvalidPatternException: if two gaps are adjacent.
        """
        result = {}
        low = high = 0
        count = 0
        gap = None  # wildcard waiting for the sub-word that follows it
        for token in tokens:
            if isinstance(token, tuple):
                if gap is not None:
                    raise InvalidPatternException
                gap = token
                continue
            if gap is not None:
                low += gap[0]
                high += gap[1]
                gap = None
            low += len(token)
            high += len(token)
            result.setdefault(token, []).append((low, high))
            count += 1
        # A trailing wildcard (gap still pending here) is ignored.
        return result, count

    def iter_words(self):
        """Yield ``(word, offsets)`` for every distinct sub-word."""
        for word, offsets in self._result.items():
            yield word, offsets

    def get_pattern(self):
        """Return the original pattern string."""
        return self._pattern

    def get_count(self):
        """Return the number of sub-word occurrences in the pattern."""
        return self._count
if __name__ == "__main__":
    # Demo: build the automaton from a few patterns and filter sample text.
    # print(...) with a single argument behaves the same on Python 2 and 3.
    def myprint(r):
        """Print the filter result *r*, grouped by start position."""
        print(u"\033[33m过滤结果:\033[0m")
        for pos, patterns in r.items():
            print(u" 位置:%s" % pos)
            print(" " + "\n ".join(
                [u"%s, %d" % tuple(pattern) for pattern in patterns]
            ))

    ac = FilterSensitiveWords()
    for pattern_string in [
            u"日{0,3}本",
            u"日{0,3}本{0,3}鬼{0,3}子",
            u"大{0,3}傻{0,3}叉",
            u"狗娘养的"]:
        print(u"\033[31m敏感词:%s\033[0m" % pattern_string)
        pattern = Pattern(pattern_string)
        for word, offsets in pattern.iter_words():
            ac.add_sensitive_word(word, pattern.get_pattern(),
                                  offsets, pattern.get_count())
    ac.build()
    # Run every sample; previously the first one was overwritten unused.
    for string in [
            u"大家都知道:日(大)本(傻)鬼(叉)子都是狗娘养的",
            u"日日本本鬼子"]:
        print(u"\033[32m过滤语句:%s\033[0m" % string)
        myprint(ac.filter(string))