In [1]:
import pandas as pd
import random
import numpy as np
from random import randint

import torch
from transformers import AutoTokenizer, AutoModel

import time

import memory_profiler

%load_ext memory_profiler

from pathlib import Path

In [2]:
import black
import jupyter_black

# Activate jupyter_black for this notebook, configured with a 79-character
# line length.
jupyter_black.load(line_length=79)

In [3]:
%load_ext autoreload
%autoreload 2

from pubmed_landscape_src.metrics import knn_accuracy_ls
from pubmed_landscape_src.data import generate_embeddings

In [4]:
# Output locations for saved variables and figures (relative paths, resolved
# against the current working directory).
variables_path, figures_path = (
    Path("../../results/variables"),
    Path("../../results/figures"),
)

# Input data location. This is an absolute, machine-specific path; adjust it
# when running the notebook on another system.
berenslab_data_path = Path("/gpfs01/berens/data/data/pubmed_processed")

# Import data

In [5]:
# Read the pickled dataframe of labeled papers, renumber its rows from zero,
# and extract the abstract texts as a plain Python list.
df = pd.read_pickle(
    berenslab_data_path / "df_labeled_papers_subset"
).reset_index(drop=True)
abstracts = df["AbstractText"].tolist()

# Obtain embeddings

In [6]:
# random seed
# `random.seed()` returns None, so the seed value itself is kept in
# `random_state` and passed to each RNG explicitly. NumPy and torch are
# seeded as well, since both are imported by this notebook.
random_state = 42
random.seed(random_state)
torch.manual_seed(random_state)
np.random.seed(random_state)

In [7]:
# Pick the compute device: use the GPU when torch reports one, otherwise
# fall back to the CPU. (If plain "cuda" does not work, try "cuda:0".)
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"running on device: {device}")

running on device: cuda


In [8]:
# Load tokenizer and model from the same pretrained checkpoint so the two
# always match.
checkpoint = "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModel.from_pretrained(checkpoint)

print("model: PubMedBERT")

Some weights of the model checkpoint at microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext were not used when initializing BertModel: ['cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.bias', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


model: PubMedBERT


In [9]:
# set device
# Move the model to the selected device; the same `device` is later passed
# to `generate_embeddings` together with this model.
model = model.to(device)

In [None]:
%%capture cap
%%time
%%memit

embeddings_av = np.empty([len(abstracts), 768])
embeddings_sep = np.empty([len(abstracts), 768])
embeddings_cls = np.empty([len(abstracts), 768])

for i, abstr in enumerate(abstracts):
    np.save(variables_path / "experiment_iter", i)

    embd_cls, embd_sep, embd_av = generate_embeddings(
        abstr, tokenizer, model, device
    )

    embeddings_cls[i] = embd_cls
    embeddings_sep[i] = embd_sep
    embeddings_av[i] = embd_av

    if (i % 50000) == 0:
        np.save(
            berenslab_data_path
            / "embeddings/embeddings_PubMedBERT/embeddings_cls_interm",
            embeddings_cls,
        )
        np.save(
            berenslab_data_path
            / "embeddings/embeddings_PubMedBERT/embeddings_sep_interm",
            embeddings_sep,
        )
        np.save(
            berenslab_data_path
            / "embeddings/embeddings_PubMedBERT/embeddings_av_interm",
            embeddings_av,
        )

np.save(
    berenslab_data_path / "embeddings/embeddings_PubMedBERT/embeddings_cls",
    embeddings_cls,
)
np.save(
    berenslab_data_path / "embeddings/embeddings_PubMedBERT/embeddings_sep",
    embeddings_sep,
)
np.save(
    berenslab_data_path / "embeddings/embeddings_PubMedBERT/embeddings_av",
    embeddings_av,
)

tcmalloc: large alloc 6144000000 bytes == 0x7f0993ca0000 @ 
tcmalloc: large alloc 6144000000 bytes == 0x7f0825940000 @ 
tcmalloc: large alloc 6144000000 bytes == 0x7f06b75e0000 @ 


In [None]:
# Dump the stdout text held by the `cap` capture object into a log file.
with (variables_path / "verbose_batches_PubMedBERT.txt").open("w") as out_file:
    out_file.write(cap.stdout)

# kNN accuracies

In [5]:
# Read the pickled dataframe of labeled papers again, renumber its rows from
# zero, and extract the "Colors" column as the label list used for kNN.
df = pd.read_pickle(
    berenslab_data_path / "df_labeled_papers_subset"
).reset_index(drop=True)
labels = df["Colors"].tolist()

## CLS

In [8]:
# Read the saved CLS embedding matrix back from disk.
embeddings_cls = np.load(
    file=berenslab_data_path
    / "embeddings/embeddings_PubMedBERT/embeddings_cls.npy"
)

tcmalloc: large alloc 6144000000 bytes == 0x7a09a000 @ 


In [9]:
# Display the dimensions of the loaded CLS embedding array.
embeddings_cls.shape

(1000000, 768)

In [10]:
%%time
knn_accuracy_PubMedBERT_cls = knn_accuracy_ls(embeddings_cls, labels)

tcmalloc: large alloc 6082560000 bytes == 0x1e83fa000 @ 


In [11]:
# Write the stdout text stored in `cap` to the CLS kNN log file.
# NOTE(review): `cap` is only refreshed by a cell run under `%%capture cap`;
# confirm the kNN cell above is captured, otherwise this logs stale output.
with (variables_path / "verbose_knn_accuracy_PubMedBERT_cls.txt").open(
    "w"
) as out_file:
    out_file.write(cap.stdout)

In [12]:
# Show the kNN accuracy obtained for the CLS embeddings.
print(knn_accuracy_PubMedBERT_cls)

0.6038


In [13]:
# Store the CLS kNN accuracy alongside the other saved variables.
np.save(
    file=variables_path / "knn_accuracy_PubMedBERT_cls",
    arr=knn_accuracy_PubMedBERT_cls,
)

## SEP

In [14]:
# Read the saved SEP embedding matrix back from disk.
# The file was written by `np.save` from a plain float array, so no pickle
# support is needed; keeping numpy's default (allow_pickle=False) avoids
# executing pickled objects and matches how the CLS embeddings are loaded.
embeddings_sep = np.load(
    berenslab_data_path / "embeddings/embeddings_PubMedBERT/embeddings_sep.npy"
)

tcmalloc: large alloc 6144000000 bytes == 0x1e83fa000 @ 


In [15]:
# Display the dimensions of the loaded SEP embedding array.
embeddings_sep.shape

(1000000, 768)

In [16]:
%%time
knn_accuracy_PubMedBERT_sep = knn_accuracy_ls(embeddings_sep, labels)

tcmalloc: large alloc 6082560000 bytes == 0x39fc52000 @ 


In [17]:
# Write the stdout text stored in `cap` to the SEP kNN log file.
# NOTE(review): `cap` is only refreshed by a cell run under `%%capture cap`;
# confirm the kNN cell above is captured, otherwise this logs stale output.
with (variables_path / "verbose_knn_accuracy_PubMedBERT_sep.txt").open(
    "w"
) as out_file:
    out_file.write(cap.stdout)

In [18]:
# Show the kNN accuracy obtained for the SEP embeddings.
print(knn_accuracy_PubMedBERT_sep)

0.6765


In [19]:
# Store the SEP kNN accuracy alongside the other saved variables.
np.save(
    file=variables_path / "knn_accuracy_PubMedBERT_sep",
    arr=knn_accuracy_PubMedBERT_sep,
)

## Average

In [20]:
# Read the saved averaged embedding matrix back from disk.
# The file was written by `np.save` from a plain float array, so no pickle
# support is needed; keeping numpy's default (allow_pickle=False) avoids
# executing pickled objects and matches how the CLS embeddings are loaded.
embeddings_av = np.load(
    berenslab_data_path / "embeddings/embeddings_PubMedBERT/embeddings_av.npy"
)

tcmalloc: large alloc 6144000000 bytes == 0x39fc52000 @ 


In [21]:
# Display the dimensions of the loaded averaged embedding array.
embeddings_av.shape

(1000000, 768)

In [22]:
%%time
knn_accuracy_PubMedBERT_av = knn_accuracy_ls(embeddings_av, labels)

tcmalloc: large alloc 6082560000 bytes == 0x582500000 @ 


In [23]:
# Write the stdout text stored in `cap` to the averaged-embedding kNN log.
# NOTE(review): `cap` is only refreshed by a cell run under `%%capture cap`;
# confirm the kNN cell above is captured, otherwise this logs stale output.
with (variables_path / "verbose_knn_accuracy_PubMedBERT_av.txt").open(
    "w"
) as out_file:
    out_file.write(cap.stdout)

In [24]:
# Show the kNN accuracy obtained for the averaged embeddings.
print(knn_accuracy_PubMedBERT_av)

0.6444


In [25]:
# Store the averaged-embedding kNN accuracy alongside the other saved
# variables.
np.save(
    file=variables_path / "knn_accuracy_PubMedBERT_av",
    arr=knn_accuracy_PubMedBERT_av,
)