In [1]:
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
from torch.autograd import Variable

In [2]:
train_X = np.random.normal(scale=2, size=(800, 69))  # 800 samples, 69 random features
train_y = np.random.randint(0, high=10, size=(800, 1), dtype=np.int64)  # random class labels in [0, 10)
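
The labels are drawn uniformly at random, independent of the features, so there is no real signal to learn; for reference, the chance-level cross-entropy for 10 balanced classes is ln(10):

print(np.log(10))   # ≈ 2.3026, the expected loss of a uniform 10-class prediction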

In [3]:
class LogisticRegression(nn.Module):  # nn.Module is the base class for all network modules
    
    def __init__(self, inputs, targets, targets_size=10, learning_rate=1e-4):
        super().__init__()  # call the parent (nn.Module) initializer from the subclass
        self._train_X = inputs
        self._train_y = targets
        self._train_X_size = inputs.shape[1]
        self._train_y_size = targets_size
        self._learning_rate = learning_rate    
        
        self._linear = nn.Linear(self._train_X_size, self._train_y_size)
        self._loss_function = nn.CrossEntropyLoss()
        self._optimizer = torch.optim.SGD(self.parameters(), lr=learning_rate)
        
        
    def fit(self, training_epochs=1e3, display=1e2):
        display = int(display)
        for epoch in range(int(training_epochs)):
            inputs = Variable(torch.FloatTensor(self._train_X), requires_grad=True)
            targets = Variable(torch.LongTensor(self._train_y.flatten()))
            self._optimizer.zero_grad()  # clear the gradients of all optimized parameters
            outputs = self._linear(inputs)  # forward pass through the linear layer
            self._loss = self._loss_function(outputs, targets)  # compute the batch loss
            self._loss.backward()  # backpropagate the error
            self._optimizer.step()
            
            if (epoch + 1) % display == 0:
                # .data[0] is the pre-0.4 way to read a scalar loss; on newer PyTorch use .item()
                print('Epoch (%d/%d), loss:%.4f' % (epoch + 1, training_epochs, self._loss.data[0]))
    
    def pred(self, X):
        outputs = self._linear(Variable(torch.FloatTensor(X)))
        _, output_labels = torch.max(outputs, 1)  # index of the highest score = predicted class
        return output_labels
            
a = LogisticRegression(train_X, train_y, 10)
a.fit(1e5, 1e3)


Epoch (1000/100000), loss:2.8511
Epoch (2000/100000), loss:2.7817
Epoch (3000/100000), loss:2.7173
Epoch (4000/100000), loss:2.6575
Epoch (5000/100000), loss:2.6021
Epoch (6000/100000), loss:2.5507
Epoch (7000/100000), loss:2.5032
Epoch (8000/100000), loss:2.4591
Epoch (9000/100000), loss:2.4184
Epoch (10000/100000), loss:2.3806
Epoch (11000/100000), loss:2.3457
Epoch (12000/100000), loss:2.3133
Epoch (13000/100000), loss:2.2834
Epoch (14000/100000), loss:2.2556
Epoch (15000/100000), loss:2.2299
Epoch (16000/100000), loss:2.2061
Epoch (17000/100000), loss:2.1840
Epoch (18000/100000), loss:2.1636
Epoch (19000/100000), loss:2.1446
Epoch (20000/100000), loss:2.1270
Epoch (21000/100000), loss:2.1106
Epoch (22000/100000), loss:2.0954
Epoch (23000/100000), loss:2.0813
Epoch (24000/100000), loss:2.0681
Epoch (25000/100000), loss:2.0559
Epoch (26000/100000), loss:2.0446
Epoch (27000/100000), loss:2.0340
Epoch (28000/100000), loss:2.0241
Epoch (29000/100000), loss:2.0149
Epoch (30000/100000), loss:2.0063
Epoch (31000/100000), loss:1.9983
Epoch (32000/100000), loss:1.9908
Epoch (33000/100000), loss:1.9838
Epoch (34000/100000), loss:1.9772
Epoch (35000/100000), loss:1.9711
Epoch (36000/100000), loss:1.9653
Epoch (37000/100000), loss:1.9599
Epoch (38000/100000), loss:1.9548
Epoch (39000/100000), loss:1.9501
Epoch (40000/100000), loss:1.9456
Epoch (41000/100000), loss:1.9414
Epoch (42000/100000), loss:1.9374
Epoch (43000/100000), loss:1.9336
Epoch (44000/100000), loss:1.9301
Epoch (45000/100000), loss:1.9268
Epoch (46000/100000), loss:1.9237
Epoch (47000/100000), loss:1.9207
Epoch (48000/100000), loss:1.9179
Epoch (49000/100000), loss:1.9152
Epoch (50000/100000), loss:1.9127
Epoch (51000/100000), loss:1.9103
Epoch (52000/100000), loss:1.9080
Epoch (53000/100000), loss:1.9059
Epoch (54000/100000), loss:1.9039
Epoch (55000/100000), loss:1.9019
Epoch (56000/100000), loss:1.9001
Epoch (57000/100000), loss:1.8983
Epoch (58000/100000), loss:1.8967
Epoch (59000/100000), loss:1.8951
Epoch (60000/100000), loss:1.8936
Epoch (61000/100000), loss:1.8921
Epoch (62000/100000), loss:1.8907
Epoch (63000/100000), loss:1.8894
Epoch (64000/100000), loss:1.8882
Epoch (65000/100000), loss:1.8870
Epoch (66000/100000), loss:1.8858
Epoch (67000/100000), loss:1.8847
Epoch (68000/100000), loss:1.8837
Epoch (69000/100000), loss:1.8826
Epoch (70000/100000), loss:1.8817
Epoch (71000/100000), loss:1.8808
Epoch (72000/100000), loss:1.8799
Epoch (73000/100000), loss:1.8790
Epoch (74000/100000), loss:1.8782
Epoch (75000/100000), loss:1.8774
Epoch (76000/100000), loss:1.8767
Epoch (77000/100000), loss:1.8759
Epoch (78000/100000), loss:1.8752
Epoch (79000/100000), loss:1.8746
Epoch (80000/100000), loss:1.8739
Epoch (81000/100000), loss:1.8733
Epoch (82000/100000), loss:1.8727
Epoch (83000/100000), loss:1.8721
Epoch (84000/100000), loss:1.8716
Epoch (85000/100000), loss:1.8711
Epoch (86000/100000), loss:1.8705
Epoch (87000/100000), loss:1.8700
Epoch (88000/100000), loss:1.8696
Epoch (89000/100000), loss:1.8691
Epoch (90000/100000), loss:1.8687
Epoch (91000/100000), loss:1.8682
Epoch (92000/100000), loss:1.8678
Epoch (93000/100000), loss:1.8674
Epoch (94000/100000), loss:1.8670
Epoch (95000/100000), loss:1.8666
Epoch (96000/100000), loss:1.8663
Epoch (97000/100000), loss:1.8659
Epoch (98000/100000), loss:1.8656
Epoch (99000/100000), loss:1.8653
Epoch (100000/100000), loss:1.8649
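
The loss falls from about 2.85 to 1.86, below the chance level of ln(10) ≈ 2.303, which here can only mean the linear layer is partially memorizing the noise in the random labels. For reference, a minimal sketch of the same training loop in the post-0.4 PyTorch style, where Variable wrappers are no longer needed and a scalar loss is read with .item() (the variable names and epoch counts below are illustrative, not part of the original notebook):

X = torch.FloatTensor(train_X)
y = torch.LongTensor(train_y.flatten())

model = nn.Linear(X.shape[1], 10)          # one linear layer producing logits for 10 classes
loss_fn = nn.CrossEntropyLoss()            # softmax + negative log-likelihood in one op
optimizer = torch.optim.SGD(model.parameters(), lr=1e-4)

for epoch in range(1000):
    optimizer.zero_grad()                  # clear accumulated gradients
    loss = loss_fn(model(X), y)            # forward pass + cross-entropy on the logits
    loss.backward()                        # backpropagation
    optimizer.step()                       # SGD parameter update
    if (epoch + 1) % 100 == 0:
        print('Epoch (%d/%d), loss:%.4f' % (epoch + 1, 1000, loss.item()))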

In [4]:
a.pred(train_X)


Out[4]:
Variable containing:
 3
 5
 4
 9
 4
 4
 0
 8
 3
 8
 0
 3
 2
 6
 6
 6
 7
 0
 8
 8
 8
 5
 5
 5
 4
 1
 8
 8
 3
 6
 6
 0
 1
 1
 7
 7
 3
 9
 1
 1
 5
 6
 3
 1
 6
 1
 3
 2
 3
 4
 3
 4
 1
 1
 6
 4
 8
 6
 6
 6
 3
 6
 0
 7
 8
 3
 6
 7
 4
 6
 8
 0
 5
 3
 3
 5
 4
 1
 7
 2
 1
 7
 4
 0
 6
 6
 6
 7
 7
 1
 7
 0
 5
 3
 3
 7
 9
 0
 9
 6
 3
 3
 6
 1
 2
 8
 4
 1
 5
 4
 2
 8
 8
 4
 7
 3
 9
 5
 8
 1
 4
 0
 5
 7
 8
 2
 6
 2
 3
 8
 3
 7
 4
 5
 3
 3
 1
 8
 5
 4
 4
 9
 0
 4
 4
 7
 1
 3
 5
 9
 6
 7
 4
 9
 5
 8
 2
 4
 1
 4
 4
 3
 8
 4
 7
 5
 3
 1
 9
 6
 1
 5
 1
 6
 3
 8
 0
 8
 6
 1
 9
 9
 8
 9
 0
 4
 5
 1
 9
 3
 2
 0
 4
 0
 1
 8
 5
 4
 2
 3
 2
 1
 1
 4
 9
 0
 2
 3
 3
 0
 5
 6
 1
 7
 9
 4
 0
 5
 0
 9
 3
 0
 8
 2
 7
 9
 3
 0
 8
 0
 0
 6
 9
 0
 8
 6
 4
 1
 7
 4
 8
 8
 6
 5
 8
 8
 1
 0
 9
 5
 4
 7
 2
 4
 6
 3
 9
 2
 3
 3
 8
 3
 7
 9
 5
 3
 9
 5
 1
 1
 3
 0
 9
 1
 0
 1
 3
 7
 8
 4
 5
 8
 1
 2
 2
 3
 1
 4
 1
 5
 3
 4
 9
 3
 7
 3
 5
 9
 5
 7
 4
 5
 3
 5
 3
 8
 8
 1
 4
 1
 5
 3
 3
 7
 9
 4
 8
 7
 9
 6
 1
 7
 7
 3
 1
 5
 2
 6
 6
 6
 4
 0
 8
 3
 9
 2
 7
 8
 0
 0
 4
 5
 8
 2
 5
 7
 7
 7
 1
 1
 0
 5
 7
 3
 3
 3
 4
 8
 2
 8
 9
 8
 6
 9
 1
 5
 6
 8
 0
 8
 1
 4
 8
 4
 8
 7
 2
 7
 7
 7
 7
 3
 5
 9
 7
 1
 9
 7
 1
 9
 5
 0
 4
 7
 9
 5
 6
 0
 7
 7
 3
 1
 5
 5
 3
 6
 5
 4
 7
 5
 8
 9
 4
 9
 5
 1
 0
 8
 2
 3
 3
 0
 9
 1
 4
 0
 6
 8
 4
 2
 4
 3
 4
 6
 4
 3
 9
 3
 8
 2
 3
 3
 9
 0
 1
 4
 8
 7
 0
 8
 3
 7
 4
 9
 6
 2
 2
 3
 8
 4
 1
 5
 2
 8
 5
 2
 0
 6
 4
 0
 8
 6
 2
 9
 3
 3
 6
 5
 1
 7
 7
 3
 2
 5
 1
 0
 5
 3
 6
 4
 4
 4
 5
 3
 9
 3
 1
 8
 0
 5
 1
 3
 8
 3
 2
 9
 9
 8
 3
 3
 1
 2
 1
 9
 1
 4
 4
 8
 8
 5
 2
 1
 4
 8
 2
 8
 0
 8
 9
 3
 4
 4
 1
 9
 3
 1
 7
 7
 9
 7
 8
 8
 5
 4
 3
 0
 3
 3
 0
 6
 1
 5
 3
 7
 0
 3
 4
 8
 1
 2
 3
 3
 2
 0
 5
 1
 8
 9
 8
 7
 2
 6
 9
 3
 6
 2
 3
 1
 3
 1
 7
 0
 2
 5
 4
 6
 3
 7
 0
 3
 6
 5
 7
 1
 1
 8
 4
 3
 4
 5
 0
 9
 6
 6
 5
 1
 8
 4
 8
 9
 4
 4
 3
 6
 3
 6
 6
 7
 0
 3
 6
 2
 4
 4
 0
 3
 9
 2
 9
 4
 4
 0
 0
 5
 6
 9
 8
 3
 7
 1
 6
 8
 5
 7
 8
 9
 1
 6
 3
 6
 4
 0
 9
 9
 6
 9
 7
 5
 6
 7
 0
 8
 5
 8
 8
 2
 3
 4
 4
 8
 0
 8
 0
 6
 3
 4
 3
 4
 4
 0
 7
 8
 1
 7
 4
 7
 1
 2
 8
 8
 7
 1
 6
 3
 7
 8
 5
 0
 5
 4
 5
 4
 5
 9
 2
 2
 5
 7
 2
 3
 0
 0
 8
 5
 8
 9
 7
 5
 6
 6
 8
 8
 6
 3
 8
 9
 8
 1
 0
 4
 8
 8
 9
 4
 3
 6
 0
 4
 1
 2
 8
 0
 1
 7
 6
 4
 3
 5
 2
 7
 8
 1
 5
 8
 7
 3
 0
 5
 2
 7
 8
 8
 4
 8
 6
 5
 4
 3
 3
 6
 7
 6
 4
 1
 4
 9
 2
 8
 5
 0
 5
 1
 2
 6
 7
 4
 5
 0
 1
 8
 0
 3
 8
 2
 4
 5
 8
 9
 7
 8
[torch.LongTensor of size 800]
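
Because the labels are random, any training accuracy above 10% is again just memorization. A minimal sketch for scoring these predictions against the labels, assuming the `a`, `train_X`, and `train_y` objects from the earlier cells:

pred_labels = a.pred(train_X).data.numpy().flatten()   # predicted class indices as a NumPy array
train_acc = float(np.mean(pred_labels == train_y.flatten()))
print('training accuracy: %.3f' % train_acc)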

In [5]:
torch.max?
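
The ? query brings up the torch.max docstring. With a dim argument, torch.max returns a pair (maximum values, indices of the maxima), which is why pred above unpacks two results; a small illustration (the example tensor is made up):

scores = torch.FloatTensor([[0.1, 0.7, 0.2],
                            [0.5, 0.3, 0.2]])
values, indices = torch.max(scores, 1)   # row-wise maximum and the column where it occurs
print(values)    # 0.7 and 0.5
print(indices)   # 1 and 0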