nvidia-oliver-holworthy committed
Commit 823b162 · unverified · 1 parent: d8e9858

Update padding implementation to reduce memory footprint

Files changed (1):
  1. modeling_qwen3_vl_nemotron_embed.py +41 -22
modeling_qwen3_vl_nemotron_embed.py CHANGED
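This commit replaces the F.pad-and-torch.cat padding in ColBERTScoringMixin.padding_various_shape_tensor with a module-level _pad_and_stack_embeddings helper that pre-allocates the output tensor once and copies each batch into it in place, dropping each source batch as soon as it has been copied. The old approach held the original batches, their padded copies, and the concatenated result in memory at the same time; the new one only ever holds the result plus the batches not yet copied. Short sketches of the helper's behavior and the memory arithmetic follow the relevant hunks below.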
@@ -94,6 +94,42 @@ def _create_bidirectional_mask(
     return None
 
 
+def _pad_and_stack_embeddings(tensors: List[torch.Tensor]) -> torch.Tensor:
+    """Pad embedding tensors to uniform sequence length and concatenate.
+
+    Args:
+        tensors: List of tensors with shape (batch, seq_len, hidden_dim).
+            Each tensor may have a different seq_len.
+
+    Returns:
+        Concatenated tensor with shape (total_batch, max_seq_len, hidden_dim),
+        where sequences shorter than max_seq_len are zero-padded.
+    """
+    if not tensors:
+        raise ValueError("Cannot pad empty tensor list")
+
+    max_seq_len = max(t.shape[1] for t in tensors)
+    total_docs = sum(t.shape[0] for t in tensors)
+    hidden_dim = tensors[0].shape[2]
+    dtype = tensors[0].dtype
+
+    # Pre-allocate result tensor
+    result = torch.zeros(total_docs, max_seq_len, hidden_dim, dtype=dtype)
+
+    # Copy in-place and release references to free memory
+    offset = 0
+    for i in range(len(tensors)):
+        t = tensors[i]
+        tensors[i] = None  # Release reference immediately
+        batch_size = t.shape[0]
+        seq_len = t.shape[1]
+        result[offset : offset + batch_size, :seq_len, :] = t
+        offset += batch_size
+        del t
+
+    return result
+
+
 class Qwen3VLNemotronEmbedTextModel(Qwen3VLTextModel):
     """Bidirectional text model for Qwen3VLNemotronEmbed."""
 
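A minimal usage sketch of the new helper (shapes and values are illustrative, not from the repository; note the helper deliberately nulls out the entries of the list you pass in):

import torch

# Two batches of per-token embeddings with different sequence lengths,
# shaped (batch, seq_len, hidden_dim); hidden_dim=4 is arbitrary here.
a = torch.randn(2, 5, 4)
b = torch.randn(3, 8, 4)

out = _pad_and_stack_embeddings([a, b])
print(out.shape)               # torch.Size([5, 8, 4])
print(out[0, 5:].abs().sum())  # tensor(0.) -- rows past batch a's seq_len stay zero-padded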
@@ -247,7 +283,7 @@ class EmbeddingMixin:
             Tensor of embeddings with shape (num_samples, max_seq_len, hidden_dim).
         """
         device = next(self.parameters()).device
-        qs = []
+        embedding_batches = []
         message = "query" if is_query else "document"
 
         for batch in tqdm(dataloader, desc=f"Extracting {message} embeddings..."):
@@ -269,10 +305,9 @@ class EmbeddingMixin:
             if not torch.isfinite(embeddings).all():
                 raise ValueError("Embeddings contain NaN or Inf values")
 
-            qs.append(embeddings.detach().cpu())
+            embedding_batches.append(embeddings.detach().cpu())
 
-        all_embeddings_tensor = self.padding_various_shape_tensor(qs)
-        return all_embeddings_tensor
+        return _pad_and_stack_embeddings(embedding_batches)
 
     def forward_queries(self, queries: List[str], batch_size: int = 8) -> torch.Tensor:
         """Forward text queries and extract embeddings.
@@ -357,22 +392,6 @@ class EmbeddingMixin:
 class ColBERTScoringMixin:
     """Mixin providing ColBERT MaxSim scoring methods."""
 
-    def padding_various_shape_tensor(self, tensors: List[torch.Tensor]) -> torch.Tensor:
-        """Pad tensors of various shapes for ColBERT-like scoring.
-
-        Args:
-            tensors: List of tensors with shape (batch, seq_len, hidden_dim)
-
-        Returns:
-            Concatenated tensor with all sequences padded to max length.
-        """
-        max_seq_len = max(t.shape[1] for t in tensors)
-        padded_tensors = [
-            F.pad(t, (0, 0, 0, max_seq_len - t.shape[1]), mode="constant", value=0)
-            for t in tensors
-        ]
-        return torch.cat(padded_tensors, dim=0)
-
     def colbert_score(
         self,
         qs: Union[torch.Tensor, List[torch.Tensor]],
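For scale, a rough, unmeasured comparison of peak host memory between the removed implementation above and the new helper, writing T for the element count of the final tensor (total_docs * max_seq_len * hidden_dim):

# Old (F.pad + torch.cat):
#   padded copies of every batch  ~ T   (F.pad allocates one copy per batch)
#   + concatenated result         ~ T   (torch.cat allocates while copies are alive)
#   + the caller's original list  ~ up to T (still referenced)
#   => peak on the order of 3*T
#
# New (pre-allocate + in-place copy):
#   result                        ~ T   (one torch.zeros allocation)
#   + not-yet-copied originals    ~ up to T, released one by one (tensors[i] = None)
#   => peak on the order of 2*T, falling as the loop consumes batches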
@@ -448,11 +467,11 @@ class ColBERTScoringMixin:
         if isinstance(query_embeddings, list):
             if len(query_embeddings[0].shape) == 2:
                 query_embeddings = [q.unsqueeze(0) for q in query_embeddings]
-            query_embeddings = self.padding_various_shape_tensor(query_embeddings)
+            query_embeddings = _pad_and_stack_embeddings(query_embeddings)
         if isinstance(passage_embeddings, list):
             if len(passage_embeddings[0].shape) == 2:
                 passage_embeddings = [p.unsqueeze(0) for p in passage_embeddings]
-            passage_embeddings = self.padding_various_shape_tensor(passage_embeddings)
+            passage_embeddings = _pad_and_stack_embeddings(passage_embeddings)
 
         return self.colbert_score(
             query_embeddings, passage_embeddings, batch_size or 128
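colbert_score itself is unchanged by this commit and its body is not shown in the diff; for context, here is a minimal sketch of the standard ColBERT MaxSim it is named after, operating on tensors padded as above (assumes L2-normalized embeddings; the actual method also batches the computation and would typically mask padding):

import torch

def maxsim(qs: torch.Tensor, ps: torch.Tensor) -> torch.Tensor:
    """qs: (num_queries, q_len, dim); ps: (num_passages, p_len, dim).

    For each query token, take the similarity of its best-matching passage
    token, then sum over query tokens. Returns (num_queries, num_passages).
    """
    sims = torch.einsum("qnd,pmd->qpnm", qs, ps)  # token-level dot products
    # Zero padding rows contribute 0 under the max; real implementations
    # usually mask them so negative similarities are not inflated by padding.
    return sims.max(dim=-1).values.sum(dim=-1)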