In [1]:
# Add anna to the path
import os
import sys
module_path = os.path.abspath(os.path.join("../../../anna"))
if module_path not in sys.path:
    sys.path.append(module_path)

DATA_DIR = "../../../data"

In [2]:
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import data.dataset.reuters21578 as data

%matplotlib inline

In [3]:
# Load data
train_docs, test_docs, unused_docs, labels = data.parse(data.fetch(DATA_DIR))

In [4]:
labels_count = len(set(labels))
labels_per_doc = len(labels) / (len(train_docs) + len(test_docs))
print("# Train docs: " + str(len(train_docs)))
print("# Test docs: " + str(len(test_docs)))
print("# Labels: " + str(labels_count))
print("# Labels per doc: " + str(labels_per_doc))


# Train docs: 7770
# Test docs: 3019
# Labels: 90
# Labels per doc: 0.008341829641301325

In [5]:
pre, ax = plt.subplots(figsize=[12, 6])

ax.set_xlabel('# Labels')
ax.set_ylabel('# Instances')
n, bins, patches = ax.hist([len(d.labels) for d in test_docs])



In [6]:
for doc in train_docs[:10]:
    print("Title: {}\nLabels: {}\n".format(doc.title, doc.labels))


Title: BAHIA COCOA REVIEW
Labels: ['cocoa']

Title: NATIONAL AVERAGE PRICES FOR FARMER-OWNED RESERVE
Labels: ['grain', 'wheat', 'corn', 'barley', 'oat', 'sorghum']

Title: ARGENTINE 1986/87 GRAIN/OILSEED REGISTRATIONS
Labels: ['veg-oil', 'lin-oil', 'soy-oil', 'sun-oil', 'soybean', 'oilseed', 'corn', 'sunseed', 'grain', 'sorghum', 'wheat']

Title: CHAMPION PRODUCTS <CH> APPROVES STOCK SPLIT
Labels: ['earn']

Title: COMPUTER TERMINAL SYSTEMS <CPML> COMPLETES SALE
Labels: ['acq']

Title: COBANCO INC <CBCO> YEAR NET
Labels: ['earn']

Title: OHIO MATTRESS <OMT> MAY HAVE LOWER 1ST QTR NET
Labels: ['earn', 'acq']

Title: AM INTERNATIONAL INC <AM> 2ND QTR JAN 31
Labels: ['earn']

Title: BROWN-FORMAN INC <BFD> 4TH QTR NET
Labels: ['earn']

Title: DEAN FOODS <DF> SEES STRONG 4TH QTR EARNINGS
Labels: ['earn']