pytorch dataloader - RuntimeError: stack expects each tensor to be equal size, but got [157] at entry 0 and [154] at entry 1
Asked Answered
P

1

7

I am a beginner with pytorch. I am trying to do an aspect based sentiment analysis. I am facing the error mentioned in the subject. My code is as follows: I request help to resolve this error. Thanks in advance. I will share the entire code and the error stack. !pip install transformers

import transformers
from transformers import BertModel, BertTokenizer, AdamW, get_linear_schedule_with_warmup
import torch
import numpy as np
import pandas as pd
import seaborn as sns
from pylab import rcParams
import matplotlib.pyplot as plt
from matplotlib import rc
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix, classification_report
from collections import defaultdict
from textwrap import wrap
from torch import nn, optim
from torch.utils.data import Dataset, DataLoader
%matplotlib inline
%config InlineBackend.figure_format='retina'
sns.set(style='whitegrid', palette='muted', font_scale=1.2)
HAPPY_COLORS_PALETTE = ["#01BEFE", "#FFDD00", "#FF7D00", "#FF006D", "#ADFF02", "#8F00FF"]
sns.set_palette(sns.color_palette(HAPPY_COLORS_PALETTE))
rcParams['figure.figsize'] = 12, 8
RANDOM_SEED = 42
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

df = pd.read_csv("/Users/user1/Downloads/auto_bio_copy.csv")

I am importing a csv file which has content and label as shown below:

df.head()

                     content                                      label
0   I told him I would leave the car and come back...   O O O O O O O O O O O O O O O O O O O O O O O ...
1   I had the ignition interlock device installed ...   O O O B-Negative I-Negative I-Negative O O O O...
2   Aug. 23 or 24 I went to Walmart auto service d...   O O O O O O O B-Negative I-Negative I-Negative...
3   Side note This is the same reaction I 'd gotte...   O O O O O O O O O O O O O O O O O O O O O O O ...
4   Locked out of my car . Called for help 215pm w...   O O O O O O O O O O O O O O O O O B-Negative O...

df.shape

(1999, 2)

I am converting the label values into integers as follows: O=zero(0), B-Positive=1, I-Positive=2, B-Negative=3, I-Negative=4, B-Neutral=5, I-Neutral=6, B-Mixed=7, I-Mixed=8

df['label'] = df.label.str.replace('O', '0')
df['label'] = df.label.str.replace('B-Positive', '1')
df['label'] = df.label.str.replace('I-Positive', '2')
df['label'] = df.label.str.replace('B-Negative', '3')
df['label'] = df.label.str.replace('I-Negative', '4')
df['label'] = df.label.str.replace('B-Neutral', '5')
df['label'] = df.label.str.replace('I-Neutral', '6')
df['label'] = df.label.str.replace('B-Mixed', '7')
df['label'] = df.label.str.replace('I-Mixed', '8')

Next, converting the string to integer list as follows:

df['label'] = df['label'].str.split(' ').apply(lambda s: list(map(int, s)))
df.head()
                     content                                         label
0   I told him I would leave the car and come back...   [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...
1   I had the ignition interlock device installed ...   [0, 0, 0, 3, 4, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...
2   Aug. 23 or 24 I went to Walmart auto service d...   [0, 0, 0, 0, 0, 0, 0, 3, 4, 4, 4, 0, 0, 0, 0, ...
3   Side note This is the same reaction I 'd gotte...   [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...
4   Locked out of my car . Called for help 215pm w...   [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...
PRE_TRAINED_MODEL_NAME = 'bert-base-cased'
tokenizer = BertTokenizer.from_pretrained(PRE_TRAINED_MODEL_NAME)
token_lens = []
for txt in df.content:
  tokens = tokenizer.encode_plus(txt, max_length=512, add_special_tokens=True, truncation=True, return_attention_mask=True)
  token_lens.append(len(tokens))
MAX_LEN = 512
class Auto_Bio_Dataset(Dataset):
    def __init__(self, contents, labels, tokenizer, max_len):
        self.contents = contents
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_len = max_len
    def __len__(self):
        return len(self.contents)
    def __getitem__(self, item):
        content = str(self.contents[item])
        label = self.labels[item]
        encoding = self.tokenizer.encode_plus(
          content,
          add_special_tokens=True,
          max_length=self.max_len,
          return_token_type_ids=False,
          #padding='max_length',
          pad_to_max_length=True,
          truncation=True,
          return_attention_mask=True,
          return_tensors='pt'
        )
        return {
          'content_text': content,
          'input_ids': encoding['input_ids'].flatten(),
          'attention_mask': encoding['attention_mask'].flatten(),
          'labels': torch.tensor(label)
        }
df_train, df_test = train_test_split(
  df,
  test_size=0.1,
  random_state=RANDOM_SEED
)
df_val, df_test = train_test_split(
  df_test,
  test_size=0.5,
  random_state=RANDOM_SEED
)
df_train.shape, df_val.shape, df_test.shape
((1799, 2), (100, 2), (100, 2))
def create_data_loader(df, tokenizer, max_len, batch_size):
    ds = Auto_Bio_Dataset(
        contents=df.content.to_numpy(),
        labels=df.label.to_numpy(),
        tokenizer=tokenizer,
        max_len=max_len
  )
    return DataLoader(
        ds,
        batch_size=batch_size,
        num_workers=2
  )
BATCH_SIZE = 16
train_data_loader = create_data_loader(df_train, tokenizer, MAX_LEN, BATCH_SIZE)
val_data_loader = create_data_loader(df_val, tokenizer, MAX_LEN, BATCH_SIZE)
test_data_loader = create_data_loader(df_test, tokenizer, MAX_LEN, BATCH_SIZE)
data = next(iter(train_data_loader))
data.keys()

Error is as follows:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-71-e0a71018e473> in <module>
----> 1 data = next(iter(train_data_loader))
      2 data.keys()

~/opt/anaconda3/lib/python3.7/site-packages/torch/utils/data/dataloader.py in __next__(self)
    528             if self._sampler_iter is None:
    529                 self._reset()
--> 530             data = self._next_data()
    531             self._num_yielded += 1
    532             if self._dataset_kind == _DatasetKind.Iterable and \

~/opt/anaconda3/lib/python3.7/site-packages/torch/utils/data/dataloader.py in _next_data(self)
   1222             else:
   1223                 del self._task_info[idx]
-> 1224                 return self._process_data(data)
   1225 
   1226     def _try_put_index(self):

~/opt/anaconda3/lib/python3.7/site-packages/torch/utils/data/dataloader.py in _process_data(self, data)
   1248         self._try_put_index()
   1249         if isinstance(data, ExceptionWrapper):
-> 1250             data.reraise()
   1251         return data
   1252 

~/opt/anaconda3/lib/python3.7/site-packages/torch/_utils.py in reraise(self)
    455             # instantiate since we don't know how to
    456             raise RuntimeError(msg) from None
--> 457         raise exception
    458 
    459 

RuntimeError: Caught RuntimeError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/Users/namrathabhandarkar/opt/anaconda3/lib/python3.7/site-packages/torch/utils/data/_utils/worker.py", line 287, in _worker_loop
    data = fetcher.fetch(index)
  File "/Users/namrathabhandarkar/opt/anaconda3/lib/python3.7/site-packages/torch/utils/data/_utils/fetch.py", line 52, in fetch
    return self.collate_fn(data)
  File "/Users/namrathabhandarkar/opt/anaconda3/lib/python3.7/site-packages/torch/utils/data/_utils/collate.py", line 157, in default_collate
    return elem_type({key: default_collate([d[key] for d in batch]) for key in elem})
  File "/Users/namrathabhandarkar/opt/anaconda3/lib/python3.7/site-packages/torch/utils/data/_utils/collate.py", line 157, in <dictcomp>
    return elem_type({key: default_collate([d[key] for d in batch]) for key in elem})
  File "/Users/namrathabhandarkar/opt/anaconda3/lib/python3.7/site-packages/torch/utils/data/_utils/collate.py", line 138, in default_collate
    return torch.stack(batch, 0, out=out)
RuntimeError: stack expects each tensor to be equal size, but got [157] at entry 0 and [154] at entry 1

I found in some github post that this error can be because of batch size, so i changed the batch size to 8 and then the error is as follows:

BATCH_SIZE = 8
train_data_loader = create_data_loader(df_train, tokenizer, MAX_LEN, BATCH_SIZE)
val_data_loader = create_data_loader(df_val, tokenizer, MAX_LEN, BATCH_SIZE)
test_data_loader = create_data_loader(df_test, tokenizer, MAX_LEN, BATCH_SIZE)
data = next(iter(train_data_loader))
data.keys()
RuntimeError                              Traceback (most recent call last)
<ipython-input-73-e0a71018e473> in <module>
----> 1 data = next(iter(train_data_loader))
      2 data.keys()

~/opt/anaconda3/lib/python3.7/site-packages/torch/utils/data/dataloader.py in __next__(self)
    528             if self._sampler_iter is None:
    529                 self._reset()
--> 530             data = self._next_data()
    531             self._num_yielded += 1
    532             if self._dataset_kind == _DatasetKind.Iterable and \

~/opt/anaconda3/lib/python3.7/site-packages/torch/utils/data/dataloader.py in _next_data(self)
   1222             else:
   1223                 del self._task_info[idx]
-> 1224                 return self._process_data(data)
   1225 
   1226     def _try_put_index(self):

~/opt/anaconda3/lib/python3.7/site-packages/torch/utils/data/dataloader.py in _process_data(self, data)
   1248         self._try_put_index()
   1249         if isinstance(data, ExceptionWrapper):
-> 1250             data.reraise()
   1251         return data
   1252 

~/opt/anaconda3/lib/python3.7/site-packages/torch/_utils.py in reraise(self)
    455             # instantiate since we don't know how to
    456             raise RuntimeError(msg) from None
--> 457         raise exception
    458 
    459 

RuntimeError: Caught RuntimeError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/Users/namrathabhandarkar/opt/anaconda3/lib/python3.7/site-packages/torch/utils/data/_utils/worker.py", line 287, in _worker_loop
    data = fetcher.fetch(index)
  File "/Users/namrathabhandarkar/opt/anaconda3/lib/python3.7/site-packages/torch/utils/data/_utils/fetch.py", line 52, in fetch
    return self.collate_fn(data)
  File "/Users/namrathabhandarkar/opt/anaconda3/lib/python3.7/site-packages/torch/utils/data/_utils/collate.py", line 157, in default_collate
    return elem_type({key: default_collate([d[key] for d in batch]) for key in elem})
  File "/Users/namrathabhandarkar/opt/anaconda3/lib/python3.7/site-packages/torch/utils/data/_utils/collate.py", line 157, in <dictcomp>
    return elem_type({key: default_collate([d[key] for d in batch]) for key in elem})
  File "/Users/namrathabhandarkar/opt/anaconda3/lib/python3.7/site-packages/torch/utils/data/_utils/collate.py", line 137, in default_collate
    out = elem.new(storage).resize_(len(batch), *list(elem.size()))
RuntimeError: Trying to resize storage that is not resizable

I am not sure what is causing the first error(the one mentioned in subject). I am using padding and truncate in my code, yet the error.

Any help to resolve this issue is highly appreciated.

Thanks in advance.

Planksheer answered 24/5, 2022 at 13:26 Comment(0)
B
8

Quick answer: you need to implement your own collate_fn function when creating a DataLoader. See the discussion from PyTorch forum.

You should be able to pass the function object to DataLoader instantiation:

def my_collate_fn(data):
    # TODO: Implement your function
    # But I guess in your case it should be:
    return tuple(data)

return DataLoader(
    ds,
    batch_size=batch_size,
    num_workers=2,
    collate_fn=my_collate_fn
)

This should be the way to solving this, but as a temporary remedy in case anything is urgent or a quick test is nice, simply change batch_size to 1 to prevent torch from trying to stack things with different shapes up.

Barnsley answered 4/7, 2022 at 16:47 Comment(1)
Thank you very much. The batch_size=1 helped. Will try the function as well.Planksheer

© 2022 - 2024 — McMap. All rights reserved.