
Commit a6ca71a

fix issue with unk in unigram (#227)
1 parent 7905629 commit a6ca71a

1 file changed (+5, -11)


model2vec/distill/tokenizer.py

Lines changed: 5 additions & 11 deletions
@@ -111,8 +111,7 @@ def _process_wordpiece(
     tokenizer_json: dict[str, Any], pre_tokenized_tokens: list[str], unk_token: str | None
 ) -> dict[str, Any]:
     """Process the WordPiece tokenizer JSON."""
-    unk_token = unk_token or tokenizer_json["model"]["unk_token"]
-    tokenizer_json["model"]["unk_token"] = "[UNK]" if unk_token else None
+    tokenizer_json["model"]["unk_token"] = unk_token
     tokenizer_json["model"]["vocab"] = {token: idx for idx, token in enumerate(pre_tokenized_tokens)}
 
     return tokenizer_json
@@ -128,20 +127,15 @@ def _process_bpe(tokenizer_json: dict[str, Any], pre_tokenized_tokens: list[str]
     return tokenizer_json
 
 
-def _process_unigram(
-    tokenizer_json: dict[str, Any], pre_tokenized_tokens: list[str], unk_token: str | None
-) -> dict[str, Any]:
+def _process_unigram(tokenizer_json: dict[str, Any], pre_tokenized_tokens: list[str], unk_token: str) -> dict[str, Any]:
     """Process the Unigram tokenizer JSON."""
-    unk_id = tokenizer_json["model"]["unk_id"]
-    vocab = tokenizer_json["model"]["vocab"]
-    unk_token = vocab[unk_id][0] if unk_id is not None else None
     current_probas = dict(tokenizer_json["model"]["vocab"])
     avg_proba = sum(current_probas.values()) / len(current_probas)
     new_probas = {word: current_probas.get(word, avg_proba) for word in pre_tokenized_tokens}
     tokenizer_json["model"]["vocab"] = sorted(new_probas.items(), key=lambda x: x[1], reverse=True)
 
     tokens, _ = zip(*tokenizer_json["model"]["vocab"])
-    tokenizer_json["model"]["unk_id"] = list(tokens).index(unk_token) if unk_token in tokens else None
+    tokenizer_json["model"]["unk_id"] = list(tokens).index(unk_token)
 
     return tokenizer_json
 
@@ -168,11 +162,11 @@ def replace_vocabulary(
     tokenizer_json["added_tokens"] = [x for x in added_tokens if x["content"] in {"[UNK]", "[PAD]"}]
 
     if model_type == "WordPiece":
-        tokenizer_json = _process_wordpiece(tokenizer_json, pre_tokenized_tokens, unk_token)
+        tokenizer_json = _process_wordpiece(tokenizer_json, pre_tokenized_tokens, "[UNK]")
     elif model_type == "BPE":
         tokenizer_json = _process_bpe(tokenizer_json, pre_tokenized_tokens)
     elif model_type == "Unigram":
-        tokenizer_json = _process_unigram(tokenizer_json, pre_tokenized_tokens, unk_token)
+        tokenizer_json = _process_unigram(tokenizer_json, pre_tokenized_tokens, "[UNK]")
     else:
         raise ValueError(f"Unknown model type {model_type}")
 
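
The net effect: replace_vocabulary now always passes the literal "[UNK]" string into both _process_wordpiece and _process_unigram, rather than deriving the unknown token from the original model JSON. For a Unigram model whose unk_id was unset, the old code derived unk_token = None and could write unk_id = None back into the rebuilt tokenizer; the patched code guarantees a valid index. A minimal sketch of the patched _process_unigram run against a toy model JSON (the vocabulary entries and scores below are hypothetical, not from the repository):

from typing import Any


def _process_unigram(tokenizer_json: dict[str, Any], pre_tokenized_tokens: list[str], unk_token: str) -> dict[str, Any]:
    """Process the Unigram tokenizer JSON (patched version from this commit)."""
    # Keep scores for tokens already in the vocab; unseen tokens get the average score.
    current_probas = dict(tokenizer_json["model"]["vocab"])
    avg_proba = sum(current_probas.values()) / len(current_probas)
    new_probas = {word: current_probas.get(word, avg_proba) for word in pre_tokenized_tokens}
    tokenizer_json["model"]["vocab"] = sorted(new_probas.items(), key=lambda x: x[1], reverse=True)

    # unk_token is now always the caller-supplied "[UNK]", so unk_id is always a real index.
    tokens, _ = zip(*tokenizer_json["model"]["vocab"])
    tokenizer_json["model"]["unk_id"] = list(tokens).index(unk_token)

    return tokenizer_json


# Hypothetical Unigram model JSON with unk_id unset; the pre-patch code would
# have derived unk_token = None from it and written unk_id = None back out.
toy = {"model": {"type": "Unigram", "unk_id": None,
                 "vocab": [["[UNK]", -10.0], ["hello", -2.0], ["world", -3.0]]}}
out = _process_unigram(toy, ["[UNK]", "hello", "world", "tokens"], "[UNK]")
print(out["model"]["unk_id"])  # 3: "[UNK]" sorts last with score -10.0

Note that dropping the old `if unk_token in tokens` guard means list(tokens).index(unk_token) raises ValueError if "[UNK]" is absent from pre_tokenized_tokens, so the caller is expected to always include it in the new vocabulary.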

0 commit comments
