Skip to article frontmatterSkip to article content
Site not loading correctly?

This may be due to an incorrect BASE_URL configuration. See the MyST Documentation for reference.

from datascience import *
import numpy as np
import matplotlib

# Render plots inline in the notebook. This is an IPython "magic" command,
# so this cell only runs under Jupyter/IPython, not as plain Python.
%matplotlib inline
import matplotlib.pyplot as plots
plots.style.use('fivethirtyeight')

import warnings
# Silence all warnings to keep lecture output clean. NOTE: this also hides
# warnings that could indicate real problems in later cells.
warnings.simplefilter("ignore")

Review of the Steps in Classification & Functions

  • distance(pt1, pt2): Returns the distance between the arrays pt1 and pt2

  • row_distance(row1, row2): Returns the distance between the rows row1 and row2

  • distances(training, example): Returns a table that is training with an additional column 'Distance_to_ex' that contains the distance between example and each row of training

  • closest(training, example, k): Returns a table of the rows corresponding to the k smallest distances

  • majority_class(topk): Returns the majority class in the 'Class' column

  • classify(training, example, k): Returns the predicted class of example based on a k nearest neighbors classifier using the historical sample training

  • classify_all(training, test, k): Returns the test table with a 'Prediction' column containing the result of calling classify on each test example.

  • get_accuracy(t, prediction_label='Prediction'): Returns the accuracy, which is the fraction of rows whose value in the Prediction column matches the value in the Class column.

  • evaluate_accuracy(training, test, k): Classifies all rows of the test set and returns the accuracy.

from tqdm.notebook import tqdm  # This generates animated progress bars

def distance(pt1, pt2):
    """Return the Euclidean distance between two points.

    pt1 and pt2 may be NumPy arrays, lists, or plain numbers. Both are
    converted with np.asarray, so sequences must have matching lengths.
    """
    difference = np.asarray(pt1) - np.asarray(pt2)
    # np.sum is vectorized (the built-in sum loops in Python, which is slow for
    # high-dimensional points) and also accepts 0-d (scalar) differences.
    return np.sqrt(np.sum(difference ** 2))

def row_distance(row1, row2):
    """Return the Euclidean distance between two rows of numeric table values.

    Both rows are first converted to NumPy arrays so that the arithmetic in
    distance is performed element-wise.
    """
    first, second = np.array(row1), np.array(row2)
    return distance(first, second)

def distances(training, example):
    """
    Compute the distance between example and every row in training.

    training must have a 'Class' column; every other column is treated as a
    numeric attribute. example holds the attribute values in the same order.
    Returns a copy of training with an extra column, 'Distance_to_ex',
    holding the distance from each row to example.
    """
    attributes_only = training.drop('Class')
    # Collect the distances in a list and convert once. Calling np.append in
    # the loop copies the whole array on every iteration (quadratic time).
    row_dists = [row_distance(row, example) for row in attributes_only.rows]
    # dtype=float keeps an empty training table producing a float column.
    return training.with_column('Distance_to_ex', np.array(row_dists, dtype=float))

def closest(training, example, k):
    """
    Return the k rows of training nearest to example, closest first,
    with their distances in the 'Distance_to_ex' column.
    """
    with_dists = distances(training, example)
    ordered = with_dists.sort('Distance_to_ex')
    return ordered.take(np.arange(k))

def majority_class(topk):
    """
    Return the most frequent value in the 'Class' column of topk.

    Ties are resolved by the row order that Table.sort produces.
    """
    class_counts = topk.group('Class')
    by_frequency = class_counts.sort('count', descending=True)
    return by_frequency.column(0).item(0)

def classify(training, example, k):
    """
    Predict the class of example by majority vote among
    its k nearest neighbors in training.
    """
    neighbors = closest(training, example, k)
    return majority_class(neighbors)

def classify_all(training, test, k):
    """Classify each row of the test table and add a column of the results.

    Every row of test (with its 'Class' column removed) is classified by a
    k-nearest-neighbor vote against training. Returns a copy of test with a
    'Prediction' column.
    """
    test_attributes = test.drop('Class')
    # Build a list and convert once. np.append in a loop recopied the array on
    # every call, and appending to the empty float array from make_array()
    # turned integer class labels into floats (1 -> 1.0).
    guesses = [classify(training, test_attributes.row(i), k)
               for i in tqdm(np.arange(test.num_rows))]
    return test.with_column("Prediction", np.array(guesses))

def get_accuracy(t, prediction_label='Prediction'):
    """Return the accuracy on a test table with Class and Prediction columns.

    The accuracy is the fraction of rows whose value in the prediction_label
    column equals the value in the 'Class' column.

    Raises:
        ValueError: if t has no rows, since accuracy is undefined.
    """
    if t.num_rows == 0:
        # Previously this surfaced as an obscure ZeroDivisionError.
        raise ValueError('cannot compute accuracy of an empty table')
    correct = np.count_nonzero(t.column('Class') == t.column(prediction_label))
    return correct / t.num_rows

def evaluate_accuracy(training, test, k):
    """Return the fraction of test rows classified correctly.

    Each test row is classified by k-nearest neighbors against training,
    and the predictions are compared with the true 'Class' labels.
    """
    classified = classify_all(training, test, k)
    return get_accuracy(classified)

Text Classification

SMS Spam

from datasets import load_dataset

# Load the 'ucirvine/sms_spam' dataset and shuffle it with a fixed seed so
# the row order is reproducible across runs.
sms = load_dataset('ucirvine/sms_spam', split='train').shuffle(seed=42)
sms_texts = np.array(sms['sms'])
sms_labels = np.array(sms['label'])

# One row per message: its text and its label in 'Class'
# (presumably 1 = spam, 0 = not spam -- TODO confirm against the dataset card).
sms_tbl = Table().with_columns('Text', sms_texts, 'Class', sms_labels)
# Number of messages in each class.
sms_tbl.group('Class').show()
Loading...
# Randomly permute the Class 1 messages and show the first 5.
sms_tbl.where('Class', 1).sample(with_replacement=False).show(5)
Loading...
# Randomly permute the Class 0 messages and show the first 5.
sms_tbl.where('Class', 0).sample(with_replacement=False).show(5)
Loading...
texts = sms_tbl.column('Text')

# Turn each message into four numeric features using vectorized np.char
# string operations, keeping the label in 'Class'.
sms_data = Table().with_columns(
    'Chars', np.char.str_len(texts),  # total number of characters
    'Digits', sum(np.char.count(texts, str(d)) for d in range(10)),  # occurrences of '0'..'9'
    'Caps', sum(np.char.count(texts, chr(c)) for c in range(65, 91)),  # occurrences of ASCII 'A'..'Z'
    'Exclamations', np.char.count(texts, '!'),  # occurrences of '!'
    'Class', sms_tbl.column('Class')
)
sms_data
Loading...
# One point per message; points are grouped (colored) by Class.
sms_data.scatter('Digits', 'Caps', group='Class')
<Figure size 500x500 with 1 Axes>
# Randomly permute the rows (no seed, so results vary between runs), then
# hold out the first test_size rows as the test set and train on the rest.
shuffled = sms_data.sample(with_replacement=False)
test_size = 100
train_sms = shuffled.take(np.arange(test_size, shuffled.num_rows))
test_sms = shuffled.take(np.arange(test_size))

print('Training:', train_sms.num_rows, ' Test:', test_sms.num_rows)
# NOTE: features are not rescaled. 'Chars' counts every character, so it is
# at least as large as each of the other counts and likely dominates the
# distance computation.
evaluate_accuracy(train_sms, test_sms, 5)
Training: 5474  Test: 100
Loading...
0.97999999999999998

Rotten Tomatoes Movie Reviews

# Load the training split of the 'rotten_tomatoes' dataset and keep only
# reviews with 5 to 10 whitespace-separated words.
reviews_full = load_dataset('rotten_tomatoes', split='train')
reviews_short = reviews_full.filter(lambda x: 5 <= len(x['text'].split()) <= 10)

reviews = Table().with_columns('Text', reviews_short['text'],
                               'Class', reviews_short['label'])
reviews = reviews.sample(with_replacement=False)  # Permute the rows
# Number of reviews in each class.
reviews.group('Class')
Warning: You are sending unauthenticated requests to the HF Hub. Please set a HF_TOKEN to enable higher rate limits and faster downloads.
Loading...
# Show 5 distinct random reviews. Table.sample draws with replacement by
# default, which could show the same review twice.
reviews.sample(5, with_replacement=False)
Loading...
import re

words = [  # The most common adjectives in the data
    'good', 'bad', 'funny', 'little', 'much', 'new', 'best',
    'many', 'own', 'other', 'big', 'great', 'most', 'few',
    'real', 'first', 'full', 'american', 'romantic', 'same', 'old',
    'better', 'young', 'original', 'interesting', 'human',
    'hard', 'cinematic', 'enough', 'emotional', 'last', 'least', 'long',
    'true', 'predictable', 'visual', 'whole', 'high', 'special',
    'entertaining', 'sweet', 'enjoyable', 'narrative', 'familiar'
]
# For each word, count the positive (Class 1) and negative (Class 0) reviews
# that contain it.
counts = Table(['Word', 'Positive', 'Negative'])
for word in words:
    # Match whole words only, ignoring case. A plain substring test
    # (are.containing) would also match e.g. 'own' inside 'down' or
    # 'old' inside 'bold', inflating the counts.
    pattern = re.compile(r'\b' + re.escape(word) + r'\b', re.IGNORECASE)
    has_word = reviews.where('Text', lambda text: pattern.search(text) is not None)
    counts = counts.with_row([word, has_word.where('Class', 1).num_rows,
                                    has_word.where('Class', 0).num_rows])

counts
Loading...
# Five negative (Class 0) reviews containing 'funny': a "positive" word does
# not guarantee a positive review. Note are.containing is a substring test,
# so words such as 'unfunny' also match.
reviews.where('Text', are.containing('funny')).where('Class', 0).sample(5, with_replacement=False)
Loading...
import re

texts = reviews.column('Text')
# One feature column per word in `words`, holding how many times that word
# appears in each review; rows follow the (already permuted) order of reviews.
review_words = Table().with_column('Class', reviews.column('Class'))
for word in words:
    # Count whole-word occurrences only, ignoring case. np.char.count would
    # also count substrings, e.g. 'own' inside 'down' or 'old' inside 'bold'.
    pattern = re.compile(r'\b' + re.escape(word) + r'\b', re.IGNORECASE)
    word_counts = np.array([len(pattern.findall(text)) for text in texts])
    review_words = review_words.with_column(word, word_counts)

# Show 5 distinct random rows (Table.sample draws with replacement by default).
review_words.sample(5, with_replacement=False)
Loading...
# Hold out the first test_size rows for testing. The rows of reviews were
# permuted earlier, so this is a random split; review_words has the same row
# order as reviews.
train_reviews = review_words.take(np.arange(test_size, reviews.num_rows))
test_reviews = review_words.take(np.arange(test_size))

print('Word-count KNN:')
evaluate_accuracy(train_reviews, test_reviews, 5)
Word-count KNN:
Loading...
0.57999999999999996
# Cross-tabulate predicted class against true Class (a confusion matrix).
# NOTE: this re-runs classify_all, repeating the work of evaluate_accuracy above.
classify_all(train_reviews, test_reviews, 5).pivot('Prediction', 'Class')
Loading...
Loading...

Sentence Embeddings

Data 8 students are not responsible for the details of the code below, such as how to call embedder.encode or how it works.

from sentence_transformers import SentenceTransformer

# Load a pretrained sentence-embedding model and encode every review text
# into one fixed-length vector (one row per review; see the printed shape).
embedder = SentenceTransformer('all-MiniLM-L6-v2')
review_emb = embedder.encode(list(reviews.column('Text')), show_progress_bar=True)
print('Embedding shape:', review_emb.shape)
Loading...
BertModel LOAD REPORT from: sentence-transformers/all-MiniLM-L6-v2
Key                     | Status     |  | 
------------------------+------------+--+-
embeddings.position_ids | UNEXPECTED |  | 

Notes:
- UNEXPECTED:	can be ignored when loading from different task/architecture; not ok if you expect identical arch.
Loading...
Embedding shape: (1072, 384)
# Number of embedding dimensions to use as features. Using more dimensions
# helps accuracy, but on DataHub values above 128 may crash the kernel, so
# lower this there. Capped at the embedding width to avoid an IndexError.
n_features = min(384, review_emb.shape[1])

# Build a table with the label plus one column per embedding dimension.
cols = ['Class', reviews.column('Class')]
for i in range(n_features):
    cols += [f'Embed{i}', review_emb[:, i]]

review_emb_table = Table().with_columns(*cols)
review_emb_table.row(0)
Row(Class=1, Embed0=0.019200735, Embed1=-0.033960462, Embed2=-0.0064136791, Embed3=0.03181849, Embed4=-0.040090643, Embed5=0.094549574, Embed6=0.052497644, Embed7=-0.0090332543, Embed8=0.14175008, Embed9=-0.040151011, Embed10=0.018321985, Embed11=-0.0025268409, Embed12=0.10903085, Embed13=-0.022248896, Embed14=-0.072498359, Embed15=0.0020178349, Embed16=-0.096268445, Embed17=-0.14937204, Embed18=0.064575993, Embed19=-0.011026557, Embed20=-0.0015952713, Embed21=0.11924704, Embed22=0.091237433, Embed23=-0.0199525, Embed24=0.033193659, Embed25=-0.051121421, Embed26=-0.024570983, Embed27=-0.01853261, Embed28=0.0062042526, Embed29=0.015071545, Embed30=-0.040215664, Embed31=0.055361636, Embed32=0.066318102, Embed33=0.031858876, Embed34=-0.018447991, Embed35=-0.039461486, Embed36=0.041817583, Embed37=-0.04022523, Embed38=0.0095656691, Embed39=0.04338371, Embed40=0.046871517, Embed41=-0.035502426, Embed42=0.041842945, Embed43=0.0083235353, Embed44=-0.053895228, Embed45=0.10517883, Embed46=-0.054763012, Embed47=-0.01858039, Embed48=-0.032416202, Embed49=0.050635505, Embed50=0.051845279, Embed51=0.14508191, Embed52=-0.022398172, Embed53=0.017713087, Embed54=0.095303684, Embed55=0.045846161, Embed56=0.031230815, Embed57=-0.012224013, Embed58=-0.018359009, Embed59=0.061371267, Embed60=-0.011415593, Embed61=0.0055315895, Embed62=-0.027610051, Embed63=0.0038601193, Embed64=-0.072858371, Embed65=-0.025136555, Embed66=-0.055012017, Embed67=-0.028078653, Embed68=0.017812012, Embed69=0.057447948, Embed70=-0.032159958, Embed71=0.048654668, Embed72=0.11205786, Embed73=0.062967733, Embed74=-0.069438316, Embed75=-0.034947701, Embed76=0.0038991128, Embed77=-0.051407687, Embed78=0.051953442, Embed79=0.0041363183, Embed80=0.011574274, Embed81=-0.091281667, Embed82=-0.013339411, Embed83=0.0095055038, Embed84=0.025257861, Embed85=-0.014197022, Embed86=0.024549646, Embed87=-0.05214956, Embed88=0.025298683, Embed89=-0.029525256, Embed90=-0.066384219, Embed91=0.0037996785, Embed92=0.011668861, 
Embed93=0.029969702, Embed94=-0.013003308, Embed95=-0.014168013, Embed96=-0.035101067, Embed97=0.16861662, Embed98=-0.014492441, Embed99=0.097652614, Embed100=-0.057462871, Embed101=-0.040398788, Embed102=-0.025071179, Embed103=0.043633301, Embed104=-0.047790032, Embed105=-0.010052725, Embed106=-0.01096565, Embed107=0.012094741, Embed108=-0.024276933, Embed109=-0.051693216, Embed110=0.034855492, Embed111=-0.032485507, Embed112=0.032803603, Embed113=0.075656474, Embed114=0.0027889144, Embed115=-0.099236466, Embed116=0.036478061, Embed117=-0.057118833, Embed118=0.085922398, Embed119=-0.0079054935, Embed120=0.030957622, Embed121=0.044843595, Embed122=0.00047890045, Embed123=-0.04620802, Embed124=0.0075430772, Embed125=-0.037438426, Embed126=0.04486056, Embed127=-9.5167486e-33, Embed128=0.039664309, Embed129=-0.019245725, Embed130=0.044307977, Embed131=0.054080449, Embed132=0.026630839, Embed133=0.027727518, Embed134=-0.014169834, Embed135=0.015142147, Embed136=-0.016158871, Embed137=-0.027471203, Embed138=0.044244338, Embed139=-0.0026568449, Embed140=0.036940068, Embed141=0.037486009, Embed142=0.020068616, Embed143=-0.013498004, Embed144=0.037955794, Embed145=0.042284511, Embed146=0.037065268, Embed147=-0.044414748, Embed148=-0.017840335, Embed149=-0.095402896, Embed150=0.023636816, Embed151=-0.02077104, Embed152=-0.049024913, Embed153=0.020972947, Embed154=-0.061179101, Embed155=0.023439497, Embed156=-0.022836359, Embed157=-0.0040775314, Embed158=0.081237651, Embed159=0.013401201, Embed160=-0.039235558, Embed161=0.0035797185, Embed162=0.058674473, Embed163=-0.054377079, Embed164=0.070006497, Embed165=0.0096483352, Embed166=-0.011844392, Embed167=0.016547123, Embed168=-0.019024095, Embed169=-0.013837093, Embed170=-0.0171184, Embed171=-0.026276613, Embed172=0.011452054, Embed173=0.02993915, Embed174=0.081330098, Embed175=0.012481669, Embed176=-0.090216488, Embed177=0.0321934, Embed178=0.044502884, Embed179=0.011759531, Embed180=0.073093362, Embed181=0.006523449, 
Embed182=-0.057302188, Embed183=-0.019554673, Embed184=-0.022485184, Embed185=0.0068438458, Embed186=-0.043456648, Embed187=0.10746414, Embed188=0.0044367905, Embed189=-0.013948614, Embed190=0.02568165, Embed191=-0.029072253, Embed192=0.01319675, Embed193=0.079375818, Embed194=-0.0014680952, Embed195=0.067466237, Embed196=0.018867096, Embed197=0.070207953, Embed198=0.013850144, Embed199=-0.027475124, Embed200=-0.067408644, Embed201=-0.0069036987, Embed202=-0.0084751742, Embed203=-0.040004268, Embed204=0.12173278, Embed205=-0.028197052, Embed206=0.034309275, Embed207=0.060053106, Embed208=-0.060335215, Embed209=-0.080438673, Embed210=0.036902152, Embed211=-0.056579653, Embed212=0.081074752, Embed213=0.030882549, Embed214=0.011232291, Embed215=-0.038016863, Embed216=-0.0080078105, Embed217=-0.03542104, Embed218=0.01125655, Embed219=-0.041605588, Embed220=0.022010665, Embed221=-0.10041232, Embed222=-0.103029, Embed223=6.1037765e-33, Embed224=-0.088019446, Embed225=0.041582733, Embed226=0.017738488, Embed227=0.067682378, Embed228=0.085441992, Embed229=-0.076870069, Embed230=-0.01234725, Embed231=0.047427919, Embed232=0.0047961017, Embed233=0.066875286, Embed234=0.0066338289, Embed235=0.088315584, Embed236=0.099424116, Embed237=0.0043034237, Embed238=0.062203418, Embed239=-0.031534575, Embed240=-0.0033526367, Embed241=-0.056022335, Embed242=-0.042807195, Embed243=-8.4717751e-05, Embed244=0.022036508, Embed245=0.048915867, Embed246=-0.016679762, Embed247=0.0051740333, Embed248=-0.094783209, Embed249=0.077816412, Embed250=-0.037696257, Embed251=-0.092922889, Embed252=0.028654244, Embed253=-0.033307683, Embed254=0.082361378, Embed255=-0.056337703, Embed256=-0.049962524, Embed257=-0.044098556, Embed258=0.018897027, Embed259=-0.0061969468, Embed260=0.094167061, Embed261=0.004905649, Embed262=-0.029835809, Embed263=0.0051002661, Embed264=-0.046255313, Embed265=-0.052325688, Embed266=-0.010768991, Embed267=-0.0022940973, Embed268=-0.03541822, Embed269=0.042871132, 
Embed270=0.079488471, Embed271=0.03752647, Embed272=-0.044480648, Embed273=-0.0043271217, Embed274=-0.081752613, Embed275=-0.0025551782, Embed276=-0.073939279, Embed277=0.057892434, Embed278=-0.11598612, Embed279=0.00034971733, Embed280=-0.082671866, Embed281=-0.032798074, Embed282=0.044826441, Embed283=0.026073035, Embed284=-0.079576336, Embed285=0.0092876023, Embed286=-0.047791652, Embed287=0.053946298, Embed288=0.050642405, Embed289=0.017061325, Embed290=-0.021587756, Embed291=-0.036897097, Embed292=-0.097359911, Embed293=0.044636872, Embed294=-0.0039956463, Embed295=-0.025734246, Embed296=-0.19191447, Embed297=0.01307487, Embed298=-0.054877736, Embed299=0.014316224, Embed300=-0.026525373, Embed301=-0.020647364, Embed302=-0.063316047, Embed303=0.054835387, Embed304=-0.094343841, Embed305=-0.13981624, Embed306=0.0074783657, Embed307=-0.024427701, Embed308=-0.011857391, Embed309=-0.029126687, Embed310=0.10145913, Embed311=0.030104639, Embed312=-0.020083241, Embed313=0.033763751, Embed314=-0.053282749, Embed315=-0.076698445, Embed316=-0.0042699627, Embed317=0.043175731, Embed318=-0.031464927, Embed319=-2.4114755e-08, Embed320=-0.0048140953, Embed321=0.014067418, Embed322=-0.04238341, Embed323=-0.060727961, Embed324=0.044613361, Embed325=0.014248077, Embed326=0.031102734, Embed327=0.0047774459, Embed328=-0.038913958, Embed329=0.035548653, Embed330=0.081405006, Embed331=-0.0018778802, Embed332=-0.010515918, Embed333=0.055770621, Embed334=0.029296527, Embed335=-0.0039104316, Embed336=0.028494176, Embed337=0.032083746, Embed338=-0.024760615, Embed339=0.044377938, Embed340=-0.044858433, Embed341=0.018080927, Embed342=0.05483254, Embed343=-0.04167179, Embed344=-0.049858958, Embed345=0.024258351, Embed346=-0.020541422, Embed347=0.1049445, Embed348=0.014999316, Embed349=-0.049611948, Embed350=0.04000701, Embed351=0.035645567, Embed352=-0.09615761, Embed353=0.029795587, Embed354=-0.0080368947, Embed355=-0.0012237617, Embed356=-0.070563138, Embed357=-0.043468066, 
Embed358=-0.028659062, Embed359=-0.021974454, Embed360=-0.057645734, Embed361=-0.0092698643, Embed362=-0.041177511, Embed363=0.14451887, Embed364=-0.0080398927, Embed365=0.028745366, Embed366=-0.084894128, Embed367=-0.029856678, Embed368=-0.03837274, Embed369=-0.016820457, Embed370=-0.0023529835, Embed371=0.012338948, Embed372=0.022666285, Embed373=0.0073495391, Embed374=0.081027202, Embed375=0.022812029, Embed376=-0.00019680356, Embed377=-0.079870604, Embed378=-0.016961507, Embed379=-0.052322343, Embed380=0.017594513, Embed381=-0.041767381, Embed382=0.048701175, Embed383=-0.066063717)
# Same train/test split as the word-count features (first test_size rows are
# the test set), so the accuracies are directly comparable.
train = review_emb_table.take(np.arange(test_size, reviews.num_rows))
test = review_emb_table.take(np.arange(test_size))
evaluate_accuracy(train, test, 5)
Loading...
0.78000000000000003
# Confusion matrix for the embedding-based classifier (re-runs classify_all).
classify_all(train, test, 5).pivot('Prediction', 'Class')
Loading...
Loading...

BONUS MATERIAL

WARNING: Unfortunately, data8.datahub.berkeley.edu does not have enough RAM per student to run this code, and running the next cell will crash your kernel. (The code does run on a 2023 MacBook Air.) It also uses many features of Python and modules that we haven't covered in the course, so please don't feel like you have to understand it all.

Fine-tuned embeddings

import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

model_name = 'sentence-transformers/all-MiniLM-L6-v2'

# Load the same pretrained encoder with a new 2-class classification head.
# The head's weights are newly initialized (see the MISSING classifier.* keys
# in the load report below).
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2)

# Fine-tune only on the training rows; the first test_size reviews are held
# out, matching the train/test split used for KNN above.
finetune_texts = list(reviews.column('Text')[test_size:])
finetune_labels = torch.tensor(list(reviews.column('Class')[test_size:]))

optimizer = torch.optim.Adam(model.parameters(), lr=5e-5)
batch_size = 16

model.train()  # training mode (e.g. enables dropout)
for epoch in tqdm(range(7)):
    # Mini-batches are taken in the same fixed order every epoch (no reshuffling).
    for i in range(0, len(finetune_texts), batch_size):
        batch_texts = finetune_texts[i:i+batch_size]
        batch_labels = finetune_labels[i:i+batch_size]
        # Pad to the longest text in the batch; truncate anything beyond 64 tokens.
        inputs = tokenizer(batch_texts, return_tensors='pt', padding=True, truncation=True, max_length=64)
        loss = model(**inputs, labels=batch_labels).loss
        # Backpropagate, update the weights, then clear gradients for the next batch.
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
Loading...
BertForSequenceClassification LOAD REPORT from: sentence-transformers/all-MiniLM-L6-v2
Key                     | Status     | 
------------------------+------------+-
embeddings.position_ids | UNEXPECTED | 
classifier.bias         | MISSING    | 
classifier.weight       | MISSING    | 

Notes:
- UNEXPECTED:	can be ignored when loading from different task/architecture; not ok if you expect identical arch.
- MISSING:	those params were newly initialized because missing from the checkpoint. Consider training on your downstream task.
Loading...
model.eval()  # inference mode (e.g. disables dropout)
# Extract features for every review (training and test) from the fine-tuned
# encoder's pooler output. All texts are tokenized as a single batch.
eval_texts = list(reviews.column('Text'))
with torch.no_grad():  # no gradients needed for feature extraction
    inputs = tokenizer(eval_texts, return_tensors='pt', padding=True, truncation=True, max_length=64)
    review_bert_emb = model.bert(**inputs).pooler_output.numpy()
print('Fine-tuned embedding shape:', review_bert_emb.shape)
Fine-tuned embedding shape: (1072, 384)
# Build a table with the label plus one column per fine-tuned embedding
# dimension (all of them), then evaluate KNN on the same train/test split.
cols = ['Class', reviews.column('Class')]
for i in range(review_bert_emb.shape[1]):
    cols += [f'Embed{i}', review_bert_emb[:, i]]

review_bert_table = Table().with_columns(*cols)

train = review_bert_table.take(np.arange(test_size, reviews.num_rows))
test = review_bert_table.take(np.arange(test_size))

evaluate_accuracy(train, test, 5)
Loading...
0.81000000000000005
# For comparison, classify the held-out test reviews directly with the
# fine-tuned model's own classification head instead of KNN.
test_texts = list(reviews.column('Text')[:test_size])
test_labels = reviews.column('Class')[:test_size]

with torch.no_grad():
    inputs = tokenizer(test_texts, return_tensors='pt', padding=True, truncation=True, max_length=64)
    # Predicted class = index of the larger of the two logits for each review.
    predictions = model(**inputs).logits.argmax(dim=1).numpy()

print('BERT classifier accuracy:', np.mean(predictions == test_labels))
BERT classifier accuracy: 0.83