LLH
2024/02/14/01:14
bd39f54
raw
history blame
1.29 kB
from sklearn.naive_bayes import *
from coding.llh.visualization.draw_line_graph import draw_line_graph
from coding.llh.visualization.draw_scatter_line_graph import draw_scatter_line_graph
from coding.llh.metrics.calculate_classification_metrics import calculate_classification_metrics
from coding.llh.metrics.calculate_regression_metrics import calculate_regression_metrics
# Naive bayes classification
def naive_bayes_classification(x_train, y_train, x_test, y_test):
info = {}
# multinomial_naive_bayes_classification_model = MultinomialNB()
Gaussian_naive_bayes_classification_model = GaussianNB()
# bernoulli_naive_bayes_classification_model = BernoulliNB()
# complement_naive_bayes_classification_model = ComplementNB()
Gaussian_naive_bayes_classification_model.fit(x_train, y_train)
y_pred = Gaussian_naive_bayes_classification_model.predict(x_test).reshape(-1, 1)
# draw_scatter_line_graph(x_test, y_pred, y_test, lr_coef, lr_intercept, ["pred", "real"], "Gaussian naive bayes classification model residual plot")
info.update(calculate_regression_metrics(y_pred, y_test, "Gaussian naive bayes classification"))
info.update(calculate_classification_metrics(y_pred, y_test, "Gaussian naive bayes classification"))
return info