terça-feira, 13 de dezembro de 2016

Visualizando uma Árvore de Decisão

import numpy as np
from sklearn.datasets import load_iris
from sklearn import tree

iris = load_iris()

test_idx = [0,50,100]

# training data
train_target = np.delete(iris.target, test_idx)            #removido as amostras 0,50,100
train_data = np.delete(iris.data,test_idx, axis = 0)    #removido as caracteristicas das amostras 0,50,100

# testing data

test_target = iris.target[test_idx]        #Amostra para testar
test_data = iris.data[test_idx]            #Dados para testar

#print test_target
#print test_data

clf = tree.DecisionTreeClassifier()
clf.fit(train_data, train_target)

print test_target                        #imprime a label da amostra
print clf.predict(test_data)            #com base nos dados da amostrada passada, preve qual e a label    

#viz code
from sklearn.externals.six import StringIO
import pydotplus
dot_data = StringIO()
tree.export_graphviz(clf,
                        out_file=dot_data,
                        feature_names=iris.feature_names,
                        class_names=iris.target_names,
                        filled= True, rounded=True,
                        impurity=False)
graph = pydotplus.graph_from_dot_data(dot_data.getvalue())
graph.write_pdf("iris.pdf")

Nenhum comentário:

Postar um comentário