-
Notifications
You must be signed in to change notification settings - Fork 19
/
perceptronF.py
65 lines (51 loc) · 2.14 KB
/
perceptronF.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import numpy as np
import matplotlib.pyplot as plt
class Perceptron:
def __init__(self, epocas=100, taxa_aprendizagem=0.05):
self.epocas = 100
self.taxa_aprendizagem = 0.05
self.pesos = None # Inicializa os pesos como None
def __step_function(self, entradas):
soma_ponderada = np.dot(entradas, self.pesos[1:]) + self.pesos[0]
return 1 if soma_ponderada > 0 else 0
def testar(self, entradas):
if self.pesos is None:
raise ValueError("Modelo não treinado. Treine o modelo antes de testar.")
return self.__step_function(entradas)
def treinar(self, X, y):
# Inicializa os pesos com valores aleatórios
self.pesos = np.random.uniform(-1, 1, X.shape[1] + 1)
mse_ = []
for epoca in range(self.epocas):
previsoes = []
erros = 0
for entradas, esperado in zip(X, y):
previsao = self.__step_function(entradas)
erro = esperado - previsao
# Atualiza os pesos
self.pesos[1:] += self.taxa_aprendizagem * erro * entradas
self.pesos[0] += self.taxa_aprendizagem * erro
erros += abs(erro)
previsoes.append(previsao)
# Calcula o MSE após todas as previsões
MSE = np.square(np.subtract(y, previsoes)).mean()
mse_.append(MSE)
print(f'Época -> {epoca + 1}, MSE -> {MSE}')
# Para se não houver erros
if erros == 0:
print(f'\nTreinamento finalizado com os seguintes \npesos -> {self.pesos}')
break
# Plota o gráfico de MSE após o treinamento
plt.plot(range(1, len(mse_) + 1), mse_, marker='o')
plt.title('Gráfico de Erros por Época de Treinamento')
plt.xlabel('Época')
plt.ylabel('MSE')
plt.show()
# Exemplo de uso
X = np.array([[0, 0], [0, 1], [1, 0], [1, 1]])
y = np.array([0, 0, 0, 1])
p = Perceptron()
p.treinar(X, y)
print('\nResultados')
for i in range(X.shape[0]):
print(f'Entrada -> {X[i]}, Esperado -> {y[i]}, Resultado -> {p.testar(X[i])}')