In [1]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torch.autograd import Variable

In [2]:
# Synthetic regression data: 800 samples with 69 Gaussian features (std 2) and
# one Gaussian target column (std 0.1). The targets are drawn independently of
# the features. A seeded generator is used so that the same arrays are produced
# every time the notebook is re-run from a fresh kernel.
SEED = 42
data_rng = np.random.RandomState(SEED)
train_X = data_rng.normal(scale=2, size=(800, 69)).astype(np.float32)
train_y = data_rng.normal(scale=0.1, size=(800, 1)).astype(np.float32)

In [3]:
class LinearRegression(nn.Module):
    """Linear regression model trained with full-batch SGD on an MSE loss.

    The training data is captured at construction time, ``fit`` runs the
    optimisation loop over that data, and ``pred`` evaluates the fitted
    linear layer on new inputs.
    """

    def __init__(self, inputs, targets, learning_rate=1e-4):
        """Create the linear layer, loss function and optimizer.

        Args:
            inputs: Training features as an array whose ``shape[1]`` is the
                number of input features.
            targets: Training targets as an array whose ``shape[1]`` is the
                number of output values per sample.
            learning_rate: Step size handed to ``torch.optim.SGD``.
        """
        super(LinearRegression, self).__init__()
        self._train_X = inputs
        self._train_y = targets
        self._train_X_size = inputs.shape[1]
        self._train_y_size = targets.shape[1]
        self._learning_rate = learning_rate
        self._linear = nn.Linear(self._train_X_size, self._train_y_size)

        # Loss and optimizer
        self._loss_function = nn.MSELoss()
        self._optimizer = torch.optim.SGD(self.parameters(), lr=learning_rate)

    @staticmethod
    def _to_tensor(array):
        """Return ``array`` as a float32 tensor.

        ``nn.Linear`` parameters are float32 by default, so other dtypes
        (e.g. float64) are converted here instead of failing inside the layer.
        """
        return torch.from_numpy(np.ascontiguousarray(array, dtype=np.float32))

    def forward(self, x):
        """Apply the linear layer to the tensor ``x``."""
        return self._linear(x)

    def fit(self, training_epochs=1e3, display=1e2):
        """Run full-batch gradient descent on the stored training data.

        Args:
            training_epochs: Number of optimisation steps; truncated to an int,
                so float literals such as ``1e3`` are accepted.
            display: Print the current loss every ``display`` epochs; truncated
                to an int. A value <= 0 disables progress output.
        """
        # Built-in int() replaces the ``np.int`` alias, which no longer exists
        # in current NumPy releases.
        n_epochs = int(training_epochs)
        display = int(display)

        # The training set does not change between epochs, so the tensors are
        # created once instead of on every iteration.
        inputs = self._to_tensor(self._train_X)
        targets = self._to_tensor(self._train_y)

        for epoch in range(1, n_epochs + 1):
            self._optimizer.zero_grad()  # clear gradients left from the previous step
            outputs = self(inputs)  # forward pass
            self._loss = self._loss_function(outputs, targets)  # loss over the whole batch
            self._loss.backward()  # back-propagate the error
            self._optimizer.step()  # apply the SGD update

            if display > 0 and epoch % display == 0:
                # .item() extracts the Python float from the 0-dim loss tensor
                # (indexing it with [0] raises on current PyTorch).
                print('Epoch (%d/%d), loss:%.4f' % (epoch, n_epochs, self._loss.item()))

    def pred(self, X):
        """Return the model output for ``X`` as a NumPy array.

        Args:
            X: Array of inputs with the same number of features as the
                training inputs.

        Returns:
            float32 ``numpy.ndarray`` holding the linear layer's output.
        """
        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            return self(self._to_tensor(X)).numpy()

In [4]:
# Seed PyTorch before building the model so that the random initialisation of
# the linear layer's weights is repeatable across kernel restarts.
torch.manual_seed(0)
a = LinearRegression(train_X, train_y)
# Train with the default schedule (1e3 epochs, loss printed every 1e2 epochs).
a.fit()


Epoch (100/1000), loss:1.3313
Epoch (200/1000), loss:1.1228
Epoch (300/1000), loss:0.9495
Epoch (400/1000), loss:0.8052
Epoch (500/1000), loss:0.6846
Epoch (600/1000), loss:0.5837
Epoch (700/1000), loss:0.4991
Epoch (800/1000), loss:0.4280
Epoch (900/1000), loss:0.3680
Epoch (1000/1000), loss:0.3173

In [5]:
# Evaluate the fitted model on the training inputs; the resulting NumPy array
# is shown as the cell output.
a.pred(X=train_X)


Out[5]:
array([[-0.46727881],
       [-0.2737641 ],
       [ 1.05788863],
       [-0.40368289],
       [ 0.9073351 ],
       [-0.33162621],
       [-0.18489531],
       [ 0.17736974],
       [ 0.94153154],
       [-1.37812138],
       [ 1.08444297],
       [ 0.56933361],
       [-0.93551493],
       [-0.58512652],
       [-0.0546658 ],
       [ 0.54009169],
       [-0.01503232],
       [ 0.7445389 ],
       [-0.51376849],
       [-0.39794806],
       [-0.42978182],
       [ 0.11550209],
       [ 0.16118822],
       [ 0.89003134],
       [ 0.08549362],
       [-0.12640905],
       [ 0.42096788],
       [ 0.68774188],
       [ 0.24705368],
       [-0.38662568],
       [-0.009137  ],
       [ 0.06563276],
       [ 0.31275877],
       [ 0.86082524],
       [-0.63355726],
       [ 0.58423811],
       [ 0.59983546],
       [-0.0220359 ],
       [ 0.63871819],
       [-0.67871124],
       [ 0.07216719],
       [-0.83627725],
       [ 1.23007071],
       [ 0.17995572],
       [ 0.82321936],
       [-0.23555687],
       [ 0.247641  ],
       [-0.47473851],
       [ 0.07626043],
       [-1.07636189],
       [-0.43494901],
       [ 0.36791253],
       [ 0.11122737],
       [-0.30446425],
       [ 0.68509686],
       [ 0.08989926],
       [ 0.04426361],
       [ 0.14791211],
       [-0.56430852],
       [-0.36163244],
       [-0.1383675 ],
       [-0.36726722],
       [-0.17415518],
       [ 0.70036316],
       [-0.62004799],
       [-0.23066962],
       [-0.71386451],
       [ 0.16563061],
       [-0.05870824],
       [ 0.53666085],
       [-0.26176867],
       [-0.33016458],
       [ 0.34658405],
       [ 0.60285509],
       [ 0.35613638],
       [-0.58566266],
       [-0.67799783],
       [ 0.00185735],
       [-0.55233228],
       [ 0.98762286],
       [-0.21431956],
       [ 0.5494563 ],
       [-0.21009275],
       [ 0.69777751],
       [-0.03421655],
       [ 0.37913439],
       [ 0.18968117],
       [-1.27394319],
       [-0.77346385],
       [-0.46639839],
       [ 0.23492968],
       [ 0.9172771 ],
       [ 0.15780166],
       [ 1.55143464],
       [-0.09927257],
       [-1.14908171],
       [ 0.28747687],
       [ 0.31212932],
       [-1.04667866],
       [ 0.16796905],
       [ 0.45362186],
       [ 0.62833196],
       [ 0.53346765],
       [ 0.32354438],
       [-0.37779015],
       [-0.61570638],
       [-0.05419575],
       [ 0.69903791],
       [ 0.59337538],
       [ 0.77665079],
       [ 0.11431374],
       [-0.13102683],
       [ 0.6272881 ],
       [ 0.47650045],
       [ 1.31695855],
       [-0.32094696],
       [-0.4573687 ],
       [ 1.04163384],
       [ 0.9533608 ],
       [ 0.47069952],
       [-0.5432148 ],
       [ 0.03210211],
       [ 0.22089702],
       [ 1.28343439],
       [-0.25111344],
       [ 0.22281212],
       [-1.10027289],
       [ 1.05092216],
       [-0.12315021],
       [ 1.23107111],
       [-0.12206171],
       [ 0.79479825],
       [-0.27658173],
       [-0.03257161],
       [-0.23320147],
       [ 0.02830369],
       [ 0.4720974 ],
       [ 0.96951437],
       [ 0.93682319],
       [ 1.13310838],
       [ 0.29869992],
       [ 0.09427438],
       [ 0.88912523],
       [ 0.42220753],
       [ 0.0372902 ],
       [-1.0529213 ],
       [-0.40246773],
       [ 0.4762809 ],
       [-0.68440276],
       [-0.84210074],
       [-0.33824477],
       [-0.56899309],
       [-0.09058844],
       [-0.40620026],
       [ 0.95020068],
       [-0.88972116],
       [-0.18232229],
       [-0.29647139],
       [ 0.37699977],
       [ 0.31088534],
       [-0.55553824],
       [ 0.73128766],
       [-0.31054783],
       [ 0.24725953],
       [-0.17362219],
       [ 0.3863723 ],
       [ 0.2109144 ],
       [-0.18588248],
       [ 1.01030123],
       [ 0.4374027 ],
       [-0.31382534],
       [-1.11601281],
       [ 1.06045115],
       [-0.70684767],
       [ 0.79756087],
       [-0.27063611],
       [ 0.02249322],
       [ 0.82708502],
       [-0.33988816],
       [ 1.57267392],
       [ 0.50086051],
       [-0.17204627],
       [-0.4206275 ],
       [ 0.20145682],
       [ 0.7343356 ],
       [ 0.71742195],
       [-0.03619666],
       [ 0.33785543],
       [ 0.67208284],
       [-0.84039855],
       [-0.22263804],
       [ 0.17220879],
       [ 0.30860427],
       [-0.29043838],
       [ 0.18191525],
       [-0.21931702],
       [ 0.07588533],
       [ 0.49802545],
       [ 0.33642587],
       [ 0.38563755],
       [ 0.63775098],
       [-0.40785185],
       [ 0.32480064],
       [ 0.43162766],
       [ 0.05904747],
       [-1.50012064],
       [ 0.67142713],
       [ 0.58435488],
       [ 0.31153658],
       [ 0.35355875],
       [-0.80490679],
       [ 0.35498881],
       [ 0.74034828],
       [ 0.57115614],
       [ 0.23331717],
       [ 0.76214534],
       [-0.84203553],
       [ 1.03372097],
       [-0.46389315],
       [ 0.13253121],
       [-0.1342862 ],
       [-0.37817308],
       [ 0.23515275],
       [ 0.16335416],
       [ 0.32894057],
       [ 1.26682103],
       [-1.42864013],
       [ 0.17241034],
       [-0.10295772],
       [ 0.50833976],
       [ 1.07998681],
       [ 0.33910474],
       [ 0.27489933],
       [-0.37901121],
       [-0.44782141],
       [-0.15007547],
       [-0.27165475],
       [ 0.18508086],
       [-0.18758163],
       [ 0.27747795],
       [ 0.2698271 ],
       [ 0.44019037],
       [-0.09881674],
       [-0.52289158],
       [-0.58016407],
       [-0.55769914],
       [ 0.35220537],
       [-0.71809584],
       [ 0.19504184],
       [ 0.32061931],
       [ 0.17160225],
       [ 0.13654795],
       [ 0.71740389],
       [ 0.30679858],
       [ 0.20361051],
       [-0.07986293],
       [ 1.49066842],
       [-0.62082833],
       [-0.18991417],
       [ 0.2118383 ],
       [-0.79518676],
       [ 0.65260118],
       [ 0.16324854],
       [-0.74143744],
       [-0.32836285],
       [-0.03312915],
       [ 0.04903696],
       [ 0.46051851],
       [-0.94917846],
       [-0.30457327],
       [ 0.47191548],
       [-0.56378162],
       [ 0.28875974],
       [-0.14943188],
       [ 0.22196722],
       [-0.29553759],
       [ 0.23299366],
       [ 0.06777838],
       [ 0.28480262],
       [ 0.40217543],
       [-0.08229988],
       [-0.56035197],
       [ 0.44685695],
       [ 0.26760721],
       [-0.4480128 ],
       [ 0.36270693],
       [ 0.39184597],
       [ 1.2296629 ],
       [ 0.27685833],
       [-0.29883829],
       [ 0.87571359],
       [ 0.11941351],
       [ 0.45306501],
       [-0.49373791],
       [-0.01900671],
       [-0.71475434],
       [-1.08297241],
       [ 0.38645425],
       [-0.11002251],
       [ 0.08480822],
       [ 0.60551834],
       [ 0.68000633],
       [-0.37446484],
       [ 0.28558043],
       [ 0.14444607],
       [ 0.74918902],
       [-0.70037574],
       [-0.19860438],
       [-0.18492687],
       [-0.2467089 ],
       [-0.2727268 ],
       [ 0.46970311],
       [ 1.63833928],
       [ 0.01708406],
       [ 0.07640977],
       [ 0.01324297],
       [-0.74184459],
       [ 0.02552059],
       [-0.05239459],
       [-0.09579279],
       [-1.18469524],
       [ 1.11568236],
       [-0.37881571],
       [-0.05398729],
       [-0.13281024],
       [-0.20784631],
       [-0.63487339],
       [ 0.11241669],
       [ 0.38291845],
       [-0.56777012],
       [ 0.47283524],
       [-0.57235891],
       [ 0.62945306],
       [ 0.78239769],
       [ 0.25560462],
       [-0.06890624],
       [-0.09904327],
       [ 0.72309995],
       [ 0.27688411],
       [-0.01648499],
       [-0.2638638 ],
       [ 0.04626647],
       [ 0.2585853 ],
       [-0.08818135],
       [-0.67361593],
       [-0.33928898],
       [-0.46787903],
       [ 0.71672267],
       [ 0.19206578],
       [ 0.36239332],
       [ 0.20416805],
       [-0.08599546],
       [-0.82836252],
       [ 0.11237525],
       [ 0.92147261],
       [ 0.614411  ],
       [-0.34799397],
       [ 0.68631375],
       [ 0.93303359],
       [ 0.94622296],
       [ 0.32473505],
       [ 0.12558074],
       [-0.75656539],
       [-0.43628553],
       [ 0.22439614],
       [ 0.81300449],
       [-0.24929121],
       [ 0.67036402],
       [-0.34255323],
       [ 0.33286998],
       [-0.22178271],
       [ 0.27287039],
       [ 0.31916785],
       [-0.38992885],
       [-0.44305053],
       [-0.31106797],
       [ 0.61986959],
       [-0.25370497],
       [-0.12899595],
       [-0.42957893],
       [ 0.58483469],
       [-0.46702805],
       [-0.49193183],
       [-0.13104644],
       [-0.39723346],
       [ 1.52816248],
       [-0.78577769],
       [ 0.46633014],
       [ 0.04623763],
       [-0.24196044],
       [-0.10803916],
       [-0.2881698 ],
       [ 0.63831007],
       [-0.82407862],
       [-0.01170271],
       [ 0.69322157],
       [-0.5017786 ],
       [ 1.18716002],
       [ 0.39446518],
       [-0.24519089],
       [ 0.74265879],
       [-0.31898609],
       [-0.12841737],
       [ 0.53461844],
       [ 0.08317027],
       [ 1.40966749],
       [ 0.17753935],
       [ 0.51147407],
       [-0.99975133],
       [ 0.12488622],
       [ 0.18601325],
       [ 0.66738188],
       [ 0.42723215],
       [-0.25254002],
       [ 0.07515737],
       [ 1.0625335 ],
       [-0.6224317 ],
       [-0.04716713],
       [ 0.73639989],
       [-0.07007905],
       [-0.08152183],
       [-0.26749712],
       [ 0.7328738 ],
       [ 0.63918918],
       [-0.41725296],
       [-0.04829331],
       [ 0.20069724],
       [-0.96875739],
       [ 0.02582369],
       [ 0.87910485],
       [ 0.65803856],
       [ 0.22482005],
       [-0.17080265],
       [ 0.60706478],
       [-0.58735651],
       [-0.3689813 ],
       [ 0.02913303],
       [-0.19729325],
       [ 0.05154093],
       [ 1.3757931 ],
       [-0.29170203],
       [ 0.10140788],
       [-0.84000373],
       [ 0.40076745],
       [ 0.4035421 ],
       [-0.84458983],
       [-0.13895413],
       [-0.40040359],
       [ 0.0197158 ],
       [-0.52464128],
       [ 0.26832265],
       [ 0.17876062],
       [ 0.06820909],
       [ 0.99942982],
       [ 0.36774647],
       [ 0.28691959],
       [-0.69095826],
       [ 0.31019875],
       [-0.19261858],
       [-0.54512566],
       [-0.97174621],
       [ 0.32773006],
       [ 0.72121447],
       [ 0.42360416],
       [-0.70742697],
       [ 0.49526089],
       [ 1.51969481],
       [-0.14510942],
       [ 0.47062743],
       [-0.0805038 ],
       [-0.62218893],
       [-0.15304625],
       [-0.18709144],
       [ 0.29036117],
       [-0.35462111],
       [ 0.79737014],
       [ 1.50713539],
       [ 0.14232478],
       [ 0.30431071],
       [-0.35210389],
       [-0.0099313 ],
       [ 0.08750501],
       [ 0.52289563],
       [-0.21382996],
       [ 0.86614323],
       [-0.3611007 ],
       [ 0.66799498],
       [-0.0742574 ],
       [ 0.42857921],
       [-0.37529054],
       [-0.54469341],
       [-0.44745681],
       [-0.94318914],
       [-0.43181208],
       [-0.68224454],
       [ 0.05364816],
       [ 0.48492202],
       [-0.26673377],
       [-0.65304083],
       [ 0.10603052],
       [ 0.46871355],
       [-0.79412782],
       [ 0.07105336],
       [-0.51376021],
       [-0.3923685 ],
       [ 0.27739224],
       [ 0.28864709],
       [-0.17036608],
       [-1.16720772],
       [ 0.06956331],
       [-0.00730576],
       [-0.71766394],
       [ 0.28128678],
       [ 0.1672478 ],
       [-0.02370133],
       [ 0.50171041],
       [-0.07009865],
       [ 0.01061276],
       [ 0.05633564],
       [-0.49955001],
       [-0.2817063 ],
       [-0.19101709],
       [-0.0033436 ],
       [ 1.12950516],
       [-0.5755657 ],
       [ 0.21608806],
       [-0.8256532 ],
       [ 0.9520039 ],
       [ 0.80697286],
       [-0.07756127],
       [-0.3151671 ],
       [ 0.14263672],
       [ 0.01480057],
       [-0.11259148],
       [ 0.90663183],
       [ 0.21596032],
       [ 0.98260522],
       [-0.49363783],
       [ 0.12438676],
       [ 0.06250332],
       [-0.43569633],
       [ 0.25145367],
       [ 0.35966235],
       [-0.70552313],
       [ 0.17478001],
       [ 0.10473324],
       [-0.16089395],
       [-0.60305721],
       [ 0.19422024],
       [ 0.09758363],
       [-0.43904236],
       [-0.54022515],
       [ 0.73896372],
       [ 0.6099056 ],
       [-0.00385395],
       [ 0.43497452],
       [ 0.38447818],
       [-0.80793869],
       [ 0.73171014],
       [-0.60187632],
       [-0.44840297],
       [-1.0055145 ],
       [-0.4168795 ],
       [-0.37266061],
       [ 0.20816579],
       [-0.14508697],
       [ 0.13987735],
       [-0.72863972],
       [-0.05385435],
       [-0.58654815],
       [-0.10666194],
       [ 0.22793192],
       [-0.38269076],
       [ 0.87929487],
       [ 0.83478677],
       [ 0.51511145],
       [-0.23796147],
       [ 1.36834955],
       [ 0.08172721],
       [-0.33726388],
       [ 0.55346268],
       [-0.42647991],
       [ 0.38795599],
       [ 0.56526667],
       [-0.45462832],
       [ 0.98716998],
       [-0.30336359],
       [ 0.37951115],
       [ 0.38037071],
       [ 0.10951657],
       [-0.2639437 ],
       [ 0.48513916],
       [ 0.46821287],
       [-0.28317851],
       [ 0.4840022 ],
       [ 1.04958975],
       [ 0.10377023],
       [ 0.10505842],
       [-0.02940079],
       [-0.42901555],
       [ 0.78063476],
       [ 0.45006946],
       [ 1.0737983 ],
       [ 0.52802682],
       [ 0.18564761],
       [ 1.72601175],
       [ 0.06325509],
       [ 0.3689467 ],
       [-0.8282966 ],
       [-1.15416896],
       [-0.32920754],
       [-0.89659929],
       [ 1.05815804],
       [ 0.70554912],
       [ 0.80964494],
       [ 0.06259336],
       [ 0.01603659],
       [ 0.32098967],
       [ 0.14421856],
       [ 0.0266166 ],
       [ 1.07772565],
       [-0.51597631],
       [ 0.74300241],
       [-0.33208385],
       [ 0.45798764],
       [ 0.11365047],
       [ 0.5227387 ],
       [ 0.79766619],
       [-0.24054846],
       [ 0.2758162 ],
       [-0.09917698],
       [-0.20623481],
       [-0.04829779],
       [-0.01329603],
       [ 0.47950909],
       [ 1.02781618],
       [ 0.69702411],
       [ 0.26320827],
       [-0.0665734 ],
       [ 0.13802579],
       [ 0.40244782],
       [ 1.01526988],
       [-0.21911451],
       [-0.38922921],
       [ 0.54782462],
       [ 0.3258568 ],
       [ 0.33680406],
       [ 0.87337738],
       [-0.03875498],
       [ 1.00650442],
       [ 0.85590285],
       [-0.25818923],
       [-0.01531341],
       [-0.61894882],
       [ 0.69424355],
       [-0.69726413],
       [-0.01645393],
       [-0.07206679],
       [-0.16152772],
       [ 0.22644481],
       [-0.20973408],
       [ 1.01885641],
       [ 1.05397451],
       [ 0.32664588],
       [ 0.40778533],
       [ 0.78050768],
       [-0.16207248],
       [-0.02656209],
       [ 0.60948062],
       [ 0.78253442],
       [-0.27905527],
       [ 0.85446668],
       [ 0.29925945],
       [-0.00957713],
       [-0.26511827],
       [-0.23545912],
       [ 0.70573199],
       [ 0.40559021],
       [ 0.11790828],
       [-0.20682126],
       [ 0.45438844],
       [ 0.82508808],
       [-0.4916046 ],
       [-0.04704569],
       [-0.38345656],
       [-0.07053103],
       [-0.36543491],
       [ 1.01872051],
       [ 0.43383664],
       [-0.67635405],
       [-0.20012966],
       [ 0.05783844],
       [ 0.20844081],
       [ 0.73643136],
       [ 0.52070045],
       [-0.43074551],
       [ 0.51860076],
       [ 0.39868322],
       [-0.07907497],
       [ 0.31996629],
       [ 0.66879386],
       [-0.12571037],
       [ 0.32119367],
       [ 0.96006119],
       [ 0.40066347],
       [-0.00539017],
       [-0.15629715],
       [ 0.96241993],
       [-0.19905522],
       [ 0.85837513],
       [ 0.80206525],
       [ 0.97351396],
       [-0.24246326],
       [-0.48889527],
       [-0.41130775],
       [ 0.15380764],
       [ 0.8356629 ],
       [-0.11419173],
       [ 0.65623331],
       [ 0.1342153 ],
       [-0.05273577],
       [-0.44047645],
       [-0.23370352],
       [-0.15489477],
       [ 0.03935406],
       [-0.19836313],
       [-0.99544287],
       [ 0.55117708],
       [-0.59885556],
       [-0.11648145],
       [ 0.71632475],
       [ 0.16874671],
       [-0.67526364],
       [-0.23234156],
       [ 0.17508182],
       [ 0.08158282],
       [-0.15935606],
       [-0.99951506],
       [ 0.28044924],
       [-0.15433758],
       [-0.62045538],
       [-0.79123241],
       [ 0.58547592],
       [ 0.61869878],
       [ 0.10728819],
       [ 0.43084788],
       [-0.1509831 ],
       [-0.18121251],
       [ 0.77610898],
       [ 0.42763799],
       [ 0.29294077],
       [-0.18773636],
       [-0.52435344],
       [ 0.05225879],
       [-0.48631665],
       [ 0.44314644],
       [ 0.7868011 ],
       [ 0.36901325],
       [ 0.40812942],
       [ 0.14549929],
       [-0.60122514],
       [ 0.03218338],
       [ 1.33894801],
       [ 0.25125071],
       [-0.95109797],
       [ 0.33375832],
       [ 0.59405452],
       [-0.34858713],
       [-0.30604652],
       [ 0.23997733],
       [ 0.02507449],
       [ 0.14913729],
       [-0.19059059],
       [-0.30665812],
       [ 0.56229717],
       [ 0.01064511],
       [-0.82806253],
       [-0.14076799],
       [-0.13117707],
       [ 0.79744756],
       [-0.13299817],
       [ 0.06150784],
       [ 0.96350777],
       [-0.08150166],
       [ 0.35942987],
       [-0.87045056],
       [ 0.35863313],
       [-0.19681925],
       [ 0.3388117 ],
       [-0.49157837],
       [-0.25832838],
       [ 0.08012209],
       [ 0.6801765 ],
       [ 0.6910466 ],
       [ 0.18716562],
       [ 0.9650228 ],
       [ 0.58523625],
       [-0.61542195],
       [ 0.65314984],
       [-0.3882488 ]], dtype=float32)