feat: Add multiprocessing #141

Merged: 9 commits on Dec 27, 2024

77 changes: 58 additions & 19 deletions model2vec/model.py
@@ -1,22 +1,25 @@
from __future__ import annotations

import math
import os
from logging import getLogger
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Any, Iterator, Union

import numpy as np
from joblib import delayed
from tokenizers import Encoding, Tokenizer
from tqdm import tqdm

from model2vec.utils import load_local_model
from model2vec.utils import ProgressParallel, load_local_model

PathLike = Union[Path, str]


logger = getLogger(__name__)

_MULTIPROCESSING_THRESHOLD = 10_000 # Minimum number of sentences to use multiprocessing


class StaticModel:
def __init__(
@@ -171,6 +174,7 @@ def encode_as_sequence(
max_length: int | None = None,
batch_size: int = 1024,
show_progress_bar: bool = False,
use_multiprocessing: bool = True,
) -> list[np.ndarray] | np.ndarray:
"""
Encode a list of sentences as a list of numpy arrays of tokens.
@@ -186,24 +190,41 @@ def encode_as_sequence(
If this is None, no truncation is done.
:param batch_size: The batch size to use.
:param show_progress_bar: Whether to show the progress bar.
:param use_multiprocessing: Whether to use multiprocessing.
Even when enabled (the default), multiprocessing is only used for inputs of more than 10k sentences; smaller inputs are encoded in the main process.
:return: The encoded sentences with an embedding per token.
"""
was_single = False
if isinstance(sentences, str):
sentences = [sentences]
was_single = True

out_array: list[np.ndarray] = []
for batch in tqdm(
self._batch(sentences, batch_size),
total=math.ceil(len(sentences) / batch_size),
disable=not show_progress_bar,
):
out_array.extend(self._encode_batch_as_sequence(batch, max_length))
# Prepare all batches
sentence_batches = list(self._batch(sentences, batch_size))
total_batches = math.ceil(len(sentences) / batch_size)

# Use joblib for multiprocessing if requested, and if we have enough sentences
if use_multiprocessing and len(sentences) > _MULTIPROCESSING_THRESHOLD:
# Disable the tokenizers library's own parallelism to avoid fork-related warnings and deadlocks in the worker processes
os.environ["TOKENIZERS_PARALLELISM"] = "false"

results = ProgressParallel(n_jobs=-1, use_tqdm=show_progress_bar, total=total_batches)(
delayed(self._encode_batch_as_sequence)(batch, max_length) for batch in sentence_batches
)
out_array: list[np.ndarray] = []
for r in results:
out_array.extend(r)
else:
out_array = []
for batch in tqdm(
sentence_batches,
total=total_batches,
disable=not show_progress_bar,
):
out_array.extend(self._encode_batch_as_sequence(batch, max_length))

if was_single:
return out_array[0]

return out_array

def _encode_batch_as_sequence(self, sentences: list[str], max_length: int | None) -> list[np.ndarray]:
@@ -224,6 +245,7 @@ def encode(
show_progress_bar: bool = False,
max_length: int | None = 512,
batch_size: int = 1024,
use_multiprocessing: bool = True,
**kwargs: Any,
) -> np.ndarray:
"""
@@ -237,6 +259,8 @@ def encode(
:param max_length: The maximum length of the sentences. Any tokens beyond this length will be truncated.
If this is None, no truncation is done.
:param batch_size: The batch size to use.
:param use_multiprocessing: Whether to use multiprocessing.
Even when enabled (the default), multiprocessing is only used for inputs of more than 10k sentences; smaller inputs are encoded in the main process.
:param **kwargs: Any additional arguments. These are ignored.
:return: The encoded sentences. If a single sentence was passed, a vector is returned.
"""
@@ -245,19 +269,34 @@
sentences = [sentences]
was_single = True

out_arrays: list[np.ndarray] = []
for batch in tqdm(
self._batch(sentences, batch_size),
total=math.ceil(len(sentences) / batch_size),
disable=not show_progress_bar,
):
out_arrays.append(self._encode_batch(batch, max_length))
# Prepare all batches
sentence_batches = list(self._batch(sentences, batch_size))
total_batches = math.ceil(len(sentences) / batch_size)

out_array = np.concatenate(out_arrays, axis=0)
ids = self.tokenize(sentences=sentences, max_length=max_length)

# Use joblib for multiprocessing if requested, and if we have enough sentences
if use_multiprocessing and len(sentences) > _MULTIPROCESSING_THRESHOLD:
# Disable the tokenizers library's own parallelism to avoid fork-related warnings and deadlocks in the worker processes
os.environ["TOKENIZERS_PARALLELISM"] = "false"

results = ProgressParallel(n_jobs=-1, use_tqdm=show_progress_bar, total=total_batches)(
delayed(self._encode_batch)(batch, max_length) for batch in sentence_batches
)
out_array = np.concatenate(results, axis=0)
else:
# Don't use multiprocessing
out_arrays: list[np.ndarray] = []
for batch in tqdm(
sentence_batches,
total=total_batches,
disable=not show_progress_bar,
):
out_arrays.append(self._encode_batch(batch, max_length))
out_array = np.concatenate(out_arrays, axis=0)

if was_single:
return out_array[0]

return out_array

def _encode_batch(self, sentences: list[str], max_length: int | None) -> np.ndarray:
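For reviewers, a minimal usage sketch of the new flag (not part of the diff; the model name and corpus below are placeholders, and StaticModel.from_pretrained is the existing loading API):

from model2vec import StaticModel

model = StaticModel.from_pretrained("minishlab/potion-base-8M")  # any Model2Vec model works here
corpus = ["an example sentence"] * 20_000  # more than 10k sentences, so the joblib path is taken

# use_multiprocessing defaults to True; pass False to force the single-process tqdm loop.
embeddings = model.encode(corpus, show_progress_bar=True)
token_sequences = model.encode_as_sequence(corpus, use_multiprocessing=False)
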
39 changes: 38 additions & 1 deletion model2vec/utils.py
@@ -1,19 +1,56 @@
# -*- coding: utf-8 -*-
from __future__ import annotations

import json
import logging
import re
from importlib import import_module
from importlib.metadata import metadata
from pathlib import Path
from typing import Iterator, Protocol, cast
from typing import Any, Iterator, Protocol, cast

import numpy as np
import safetensors
from joblib import Parallel
from tokenizers import Tokenizer
from tqdm import tqdm

logger = logging.getLogger(__name__)


class ProgressParallel(Parallel):
"""A drop-in replacement for joblib.Parallel that shows a tqdm progress bar."""

def __init__(self, use_tqdm: bool = True, total: int | None = None, *args: Any, **kwargs: Any) -> None:
"""
Initialize the ProgressParallel object.

:param use_tqdm: Whether to show the progress bar.
:param total: Total number of tasks (batches) you expect to process. If None,
it updates the total dynamically to the number of dispatched tasks.
:param *args: Additional arguments to pass to `Parallel.__init__`.
:param **kwargs: Additional keyword arguments to pass to `Parallel.__init__`.
"""
self._use_tqdm = use_tqdm
self._total = total
super().__init__(*args, **kwargs)

def __call__(self, *args: Any, **kwargs: Any) -> Any:
"""Create a tqdm context."""
with tqdm(disable=not self._use_tqdm, total=self._total) as self._pbar:
return super().__call__(*args, **kwargs)

def print_progress(self) -> None:
"""Hook called by joblib as tasks complete. We update the tqdm bar here."""
if self._total is None:
# If no fixed total was given, we dynamically set the total
self._pbar.total = self.n_dispatched_tasks

# Move the bar to the number of completed tasks
self._pbar.n = self.n_completed_tasks
self._pbar.refresh()


class SafeOpenProtocol(Protocol):
"""Protocol to fix safetensors safe open."""

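A rough standalone illustration of how ProgressParallel composes with joblib (a sketch for review purposes, not part of the diff; the toy square function is made up):

from joblib import delayed
from model2vec.utils import ProgressParallel

def square(x: int) -> int:
    return x * x

# Fans the tasks out over all cores while tqdm tracks completed jobs;
# when total is omitted, the bar grows with the number of dispatched tasks.
results = ProgressParallel(n_jobs=-1, use_tqdm=True, total=100)(
    delayed(square)(i) for i in range(100)
)
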
1 change: 1 addition & 0 deletions pyproject.toml
@@ -24,6 +24,7 @@ classifiers = [

dependencies = [
"jinja2",
"joblib",
"numpy",
"rich",
"safetensors",
24 changes: 23 additions & 1 deletion tests/test_model.py
@@ -75,7 +75,7 @@ def test_encode_multiple_sentences(
assert encoded.shape == (2, 2)


def test_encode_as_tokens(mock_vectors: np.ndarray, mock_tokenizer: Tokenizer, mock_config: dict[str, str]) -> None:
def test_encode_as_sequence(mock_vectors: np.ndarray, mock_tokenizer: Tokenizer, mock_config: dict[str, str]) -> None:
"""Test encoding of sentences as tokens."""
sentences = ["word1 word2", "word1 word3"]
model = StaticModel(vectors=mock_vectors, tokenizer=mock_tokenizer, config=mock_config)
@@ -88,6 +88,28 @@ def test_encode_as_tokens(mock_vectors: np.ndarray, mock_tokenizer: Tokenizer, m
assert np.allclose(means, encoded)


def test_encode_multiprocessing(
mock_vectors: np.ndarray, mock_tokenizer: Tokenizer, mock_config: dict[str, str]
) -> None:
"""Test encoding with multiprocessing."""
model = StaticModel(vectors=mock_vectors, tokenizer=mock_tokenizer, config=mock_config)
# Generate a list of 15k inputs to test multiprocessing
sentences = ["word1 word2"] * 15_000
encoded = model.encode(sentences, use_multiprocessing=True)
assert encoded.shape == (15000, 2)


def test_encode_as_sequence_multiprocessing(
mock_vectors: np.ndarray, mock_tokenizer: Tokenizer, mock_config: dict[str, str]
) -> None:
"""Test encoding of sentences as tokens with multiprocessing."""
model = StaticModel(vectors=mock_vectors, tokenizer=mock_tokenizer, config=mock_config)
# Generate a list of 15k inputs to test multiprocessing
sentences = ["word1 word2"] * 15_000
encoded = model.encode_as_sequence(sentences, use_multiprocessing=True)
assert len(encoded) == 15_000


def test_encode_as_tokens_empty(
mock_vectors: np.ndarray, mock_tokenizer: Tokenizer, mock_config: dict[str, str]
) -> None:
36 changes: 19 additions & 17 deletions uv.lock

Some generated files are not rendered by default.