import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn

# Synthetic training set: 800 rows of 69 zero-mean Gaussian features with
# standard deviation 2, plus one integer class label in [0, 10) per row,
# stored as an (800, 1) int64 column.
train_X = np.random.normal(loc=0.0, scale=2, size=(800, 69))
train_y = np.random.randint(low=0, high=10, size=(800, 1), dtype=np.int64)

class LogistRegression(nn.Module):  # nn.Module is the base class of all PyTorch networks
    """Multinomial logistic regression: one linear layer trained with cross-entropy.

    The training data is stored at construction time; ``fit`` then runs
    full-batch SGD on it and ``pred`` returns arg-max class predictions.

    Args:
        inputs: 2-D array-like of features with shape ``(n_samples, n_features)``
            (must expose ``.shape``).
        targets: array-like of integer class labels; it is flattened, so both
            ``(n_samples,)`` and ``(n_samples, 1)`` layouts work.
        targets_size: number of classes, i.e. the output width of the linear layer.
        learning_rate: SGD step size.
    """

    def __init__(self, inputs, targets, targets_size=10, learning_rate=1e-4):
        super().__init__()  # initialise nn.Module bookkeeping before registering layers
        self._train_X = inputs
        self._train_y = targets
        self._train_X_size = inputs.shape[1]
        self._train_y_size = targets_size
        self._learning_rate = learning_rate

        self._linear = nn.Linear(self._train_X_size, self._train_y_size)
        # CrossEntropyLoss applies log-softmax internally, so the model emits raw logits.
        self._loss_function = nn.CrossEntropyLoss()
        self._optimizer = torch.optim.SGD(self.parameters(), lr=learning_rate)

    def forward(self, x):
        """Return the unnormalised class scores (logits) for the feature tensor ``x``."""
        return self._linear(x)

    def fit(self, training_epochs=1e3, display=1e2):
        """Train on the stored data with full-batch SGD.

        Args:
            training_epochs: number of optimisation steps; converted with ``int``.
            display: print the loss every ``display`` epochs; converted with
                ``int``. A value <= 0 disables printing.

        The loss tensor from the most recent epoch is kept in ``self._loss``.
        """
        n_epochs = int(training_epochs)
        display = int(display)
        # Convert once, outside the loop: nn.Linear needs float32 inputs and
        # CrossEntropyLoss needs int64 class indices of shape (n_samples,).
        features = torch.as_tensor(self._train_X, dtype=torch.float32)
        targets = torch.as_tensor(self._train_y, dtype=torch.long).reshape(-1)

        for epoch in range(n_epochs):
            # Clear last epoch's gradients; PyTorch accumulates them otherwise.
            self._optimizer.zero_grad()
            outputs = self(features)  # forward pass through the network
            self._loss = self._loss_function(outputs, targets)  # full-batch loss
            self._loss.backward()  # back-propagate the error
            self._optimizer.step()

            if display > 0 and (epoch + 1) % display == 0:
                print('Epoch (%d/%d), loss:%.4f'
                      % (epoch + 1, n_epochs, self._loss.item()))

    def pred(self, X):
        """Return the predicted class index for each row of ``X``.

        Args:
            X: 2-D array-like (NumPy array, nested list or tensor) of features.

        Returns:
            An int64 tensor of shape ``(n_samples,)`` holding, for each row,
            the index of the highest-scoring class.
        """
        features = torch.as_tensor(X, dtype=torch.float32)
        with torch.no_grad():  # inference only: no autograd graph needed
            outputs = self(features)
        _, output_labels = torch.max(outputs, 1)
        return output_labels

# Build a 10-class model on the synthetic data, then train it for 100000
# epochs, printing the loss every 1000 epochs.
a = LogistRegression(inputs=train_X, targets=train_y, targets_size=10)
a.fit(training_epochs=1e5, display=1e3)

Epoch (1000/100000), loss:2.8511
Epoch (2000/100000), loss:2.7817
Epoch (3000/100000), loss:2.7173
Epoch (4000/100000), loss:2.6575
Epoch (5000/100000), loss:2.6021
Epoch (6000/100000), loss:2.5507
Epoch (7000/100000), loss:2.5032
Epoch (8000/100000), loss:2.4591
Epoch (9000/100000), loss:2.4184
Epoch (10000/100000), loss:2.3806
Epoch (11000/100000), loss:2.3457
Epoch (12000/100000), loss:2.3133
Epoch (13000/100000), loss:2.2834
Epoch (14000/100000), loss:2.2556
Epoch (15000/100000), loss:2.2299
Epoch (16000/100000), loss:2.2061
Epoch (17000/100000), loss:2.1840
Epoch (18000/100000), loss:2.1636
Epoch (19000/100000), loss:2.1446
Epoch (20000/100000), loss:2.1270
Epoch (21000/100000), loss:2.1106
Epoch (22000/100000), loss:2.0954
Epoch (23000/100000), loss:2.0813
Epoch (24000/100000), loss:2.0681
Epoch (25000/100000), loss:2.0559
Epoch (26000/100000), loss:2.0446
Epoch (27000/100000), loss:2.0340
Epoch (28000/100000), loss:2.0241
Epoch (29000/100000), loss:2.0149
Epoch (30000/100000), loss:2.0063
Epoch (31000/100000), loss:1.9983
Epoch (32000/100000), loss:1.9908
Epoch (33000/100000), loss:1.9838
Epoch (34000/100000), loss:1.9772
Epoch (35000/100000), loss:1.9711
Epoch (36000/100000), loss:1.9653
Epoch (37000/100000), loss:1.9599
Epoch (38000/100000), loss:1.9548
Epoch (39000/100000), loss:1.9501
Epoch (40000/100000), loss:1.9456
Epoch (41000/100000), loss:1.9414
Epoch (42000/100000), loss:1.9374
Epoch (43000/100000), loss:1.9336
Epoch (44000/100000), loss:1.9301
Epoch (45000/100000), loss:1.9268
Epoch (46000/100000), loss:1.9237
Epoch (47000/100000), loss:1.9207
Epoch (48000/100000), loss:1.9179
Epoch (49000/100000), loss:1.9152
Epoch (50000/100000), loss:1.9127
Epoch (51000/100000), loss:1.9103
Epoch (52000/100000), loss:1.9080
Epoch (53000/100000), loss:1.9059
Epoch (54000/100000), loss:1.9039
Epoch (55000/100000), loss:1.9019
Epoch (56000/100000), loss:1.9001
Epoch (57000/100000), loss:1.8983
Epoch (58000/100000), loss:1.8967
Epoch (59000/100000), loss:1.8951
Epoch (60000/100000), loss:1.8936
Epoch (61000/100000), loss:1.8921
Epoch (62000/100000), loss:1.8907
Epoch (63000/100000), loss:1.8894
Epoch (64000/100000), loss:1.8882
Epoch (65000/100000), loss:1.8870
Epoch (66000/100000), loss:1.8858
Epoch (67000/100000), loss:1.8847
Epoch (68000/100000), loss:1.8837
Epoch (69000/100000), loss:1.8826
Epoch (70000/100000), loss:1.8817
Epoch (71000/100000), loss:1.8808
Epoch (72000/100000), loss:1.8799
Epoch (73000/100000), loss:1.8790
Epoch (74000/100000), loss:1.8782
Epoch (75000/100000), loss:1.8774
Epoch (76000/100000), loss:1.8767
Epoch (77000/100000), loss:1.8759
Epoch (78000/100000), loss:1.8752
Epoch (79000/100000), loss:1.8746
Epoch (80000/100000), loss:1.8739
Epoch (81000/100000), loss:1.8733
Epoch (82000/100000), loss:1.8727
Epoch (83000/100000), loss:1.8721
Epoch (84000/100000), loss:1.8716
Epoch (85000/100000), loss:1.8711
Epoch (86000/100000), loss:1.8705
Epoch (87000/100000), loss:1.8700
Epoch (88000/100000), loss:1.8696
Epoch (89000/100000), loss:1.8691
Epoch (90000/100000), loss:1.8687
Epoch (91000/100000), loss:1.8682
Epoch (92000/100000), loss:1.8678
Epoch (93000/100000), loss:1.8674
Epoch (94000/100000), loss:1.8670
Epoch (95000/100000), loss:1.8666
Epoch (96000/100000), loss:1.8663
Epoch (97000/100000), loss:1.8659
Epoch (98000/100000), loss:1.8656
Epoch (99000/100000), loss:1.8653
Epoch (100000/100000), loss:1.8649

# Predicted class index for every training sample.
a.pred(X=train_X)

Variable containing:
3
5
4
9
4
4
0
8
3
8
0
3
2
6
6
6
7
0
8
8
8
5
5
5
4
1
8
8
3
6
6
0
1
1
7
7
3
9
1
1
5
6
3
1
6
1
3
2
3
4
3
4
1
1
6
4
8
6
6
6
3
6
0
7
8
3
6
7
4
6
8
0
5
3
3
5
4
1
7
2
1
7
4
0
6
6
6
7
7
1
7
0
5
3
3
7
9
0
9
6
3
3
6
1
2
8
4
1
5
4
2
8
8
4
7
3
9
5
8
1
4
0
5
7
8
2
6
2
3
8
3
7
4
5
3
3
1
8
5
4
4
9
0
4
4
7
1
3
5
9
6
7
4
9
5
8
2
4
1
4
4
3
8
4
7
5
3
1
9
6
1
5
1
6
3
8
0
8
6
1
9
9
8
9
0
4
5
1
9
3
2
0
4
0
1
8
5
4
2
3
2
1
1
4
9
0
2
3
3
0
5
6
1
7
9
4
0
5
0
9
3
0
8
2
7
9
3
0
8
0
0
6
9
0
8
6
4
1
7
4
8
8
6
5
8
8
1
0
9
5
4
7
2
4
6
3
9
2
3
3
8
3
7
9
5
3
9
5
1
1
3
0
9
1
0
1
3
7
8
4
5
8
1
2
2
3
1
4
1
5
3
4
9
3
7
3
5
9
5
7
4
5
3
5
3
8
8
1
4
1
5
3
3
7
9
4
8
7
9
6
1
7
7
3
1
5
2
6
6
6
4
0
8
3
9
2
7
8
0
0
4
5
8
2
5
7
7
7
1
1
0
5
7
3
3
3
4
8
2
8
9
8
6
9
1
5
6
8
0
8
1
4
8
4
8
7
2
7
7
7
7
3
5
9
7
1
9
7
1
9
5
0
4
7
9
5
6
0
7
7
3
1
5
5
3
6
5
4
7
5
8
9
4
9
5
1
0
8
2
3
3
0
9
1
4
0
6
8
4
2
4
3
4
6
4
3
9
3
8
2
3
3
9
0
1
4
8
7
0
8
3
7
4
9
6
2
2
3
8
4
1
5
2
8
5
2
0
6
4
0
8
6
2
9
3
3
6
5
1
7
7
3
2
5
1
0
5
3
6
4
4
4
5
3
9
3
1
8
0
5
1
3
8
3
2
9
9
8
3
3
1
2
1
9
1
4
4
8
8
5
2
1
4
8
2
8
0
8
9
3
4
4
1
9
3
1
7
7
9
7
8
8
5
4
3
0
3
3
0
6
1
5
3
7
0
3
4
8
1
2
3
3
2
0
5
1
8
9
8
7
2
6
9
3
6
2
3
1
3
1
7
0
2
5
4
6
3
7
0
3
6
5
7
1
1
8
4
3
4
5
0
9
6
6
5
1
8
4
8
9
4
4
3
6
3
6
6
7
0
3
6
2
4
4
0
3
9
2
9
4
4
0
0
5
6
9
8
3
7
1
6
8
5
7
8
9
1
6
3
6
4
0
9
9
6
9
7
5
6
7
0
8
5
8
8
2
3
4
4
8
0
8
0
6
3
4
3
4
4
0
7
8
1
7
4
7
1
2
8
8
7
1
6
3
7
8
5
0
5
4
5
4
5
9
2
2
5
7
2
3
0
0
8
5
8
9
7
5
6
6
8
8
6
3
8
9
8
1
0
4
8
8
9
4
3
6
0
4
1
2
8
0
1
7
6
4
3
5
2
7
8
1
5
8
7
3
0
5
2
7
8
8
4
8
6
5
4
3
3
6
7
6
4
1
4
9
2
8
5
0
5
1
2
6
7
4
5
0
1
8
0
3
8
2
4
5
8
9
7
8
[torch.LongTensor of size 800]

# IPython help query (trailing `?`): show the documentation for torch.max.
torch.max?

