27 lines
1.1 KiB
Python
27 lines
1.1 KiB
Python
from sklearn.feature_extraction.text import CountVectorizer
|
|
from sklearn.linear_model import SGDClassifier
|
|
|
|
class Classifier(object):
    """Linear text classifier over word n-gram counts, retrained on demand.

    Training data comes from the ``datagrabber`` callable given at
    construction time. Every :meth:`reload` calls it again, refits a
    ``CountVectorizer`` (word 1- to 3-grams) on the returned samples and
    trains a fresh hinge-loss ``SGDClassifier`` on the resulting counts.

    Attributes set by :meth:`reload`:
        vect: the fitted ``CountVectorizer``.
        train_vec: the vectorized training matrix.
        clf: the trained ``SGDClassifier``.
        text_clf: same object as ``clf`` (``fit`` returns the estimator).
    """

    def __init__(self, datagrabber):
        """Remember the data source and train the initial model.

        Args:
            datagrabber: zero-argument callable returning an ``(Xs, Ys)``
                pair. Presumably ``Xs`` is an iterable of text documents and
                ``Ys`` the matching labels -- confirm against the caller.
        """
        self.grabber = datagrabber
        self.reload()

    @staticmethod
    def _build_sgd():
        """Create the untrained SGD estimator with fixed hyper-parameters.

        ``n_iter`` was deprecated in scikit-learn 0.19 and removed in 0.21.
        ``max_iter=500`` with ``tol=None`` reproduces its behavior (exactly
        500 epochs, no early stopping). Releases that predate ``max_iter``
        reject the keyword with ``TypeError``, in which case the legacy
        spelling is used instead.
        """
        epochs = 500
        common = dict(loss='hinge', penalty='l2', alpha=1e-3, random_state=42)
        try:
            return SGDClassifier(max_iter=epochs, tol=None, **common)
        except TypeError:
            # scikit-learn < 0.19: only the old keyword exists.
            return SGDClassifier(n_iter=epochs, **common)

    def reload(self):
        """Fetch the training data again and rebuild vectorizer and model."""
        Xs, Ys = self.grabber()

        self.vect = CountVectorizer(analyzer='word', ngram_range=(1, 3))
        self.train_vec = self.vect.fit_transform(Xs)

        self.clf = self._build_sgd()
        # fit() returns the estimator itself, so text_clf aliases clf.
        self.text_clf = self.clf.fit(self.train_vec, Ys)

    def scan(self, name):
        """Score a single string with the trained model.

        Args:
            name: the text to classify.

        Returns:
            The first element of ``decision_function`` for ``name``: the
            signed distance to the separating hyperplane for a two-class
            model, or the array of per-class scores when the model was
            trained on more than two classes.
        """
        v = self.vect.transform([name])
        return self.text_clf.decision_function(v)[0]

    def add(self, name, state):
        """Register a new labelled sample by retraining from scratch.

        ``name`` and ``state`` are currently unused: the sample is expected
        to be visible through the data grabber by the time this is called.

        An incremental implementation would use ``partial_fit``, but that
        requires switching to a hashing vectorizer, which means the model
        could no longer be reversed (features mapped back to n-grams). So
        for now the model is simply reloaded completely.
        """
        self.reload()
|