Skip to content

Commit f8dc325

Browse files
authored
add docker (#4000)
* add docker
* fix unit-test error: type promotion
* fix url
1 parent f54df90 commit f8dc325

File tree

2 files changed: +24 / -4 lines

docker/ubuntu20-cpu/Dockerfile

Lines changed: 17 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,17 @@
# PaddleSpeech image: builds PaddleSpeech from source on top of the
# PaddlePaddle 3.0.0b1 base image and installs it as a wheel.
FROM registry.baidubce.com/paddlepaddle/paddle:3.0.0b1
# NOTE(review): the maintainer address below looks like an email-obfuscation
# placeholder introduced when this copy was captured; restore the real address.
LABEL maintainer="[email protected]"

# System packages: libsndfile development and runtime packages.
# `-y` is required because `docker build` has no interactive stdin, so apt-get
# would abort at its confirmation prompt as soon as extra dependencies are
# pulled in. The apt lists are removed afterwards to keep the layer small.
RUN apt-get update \
    && apt-get install -y libsndfile-dev libsndfile1 \
    && apt-get clean \
    && rm -rf /var/lib/apt/lists/*

# Shallow clone of the PaddleSpeech default branch (not pinned to a commit, so
# rebuilding later may pick up different sources).
RUN git clone --depth 1 https://github.com/PaddlePaddle/PaddleSpeech.git /home/PaddleSpeech

# Best-effort removal of mccabe: `|| true` keeps the build going when the
# package is not installed or the uninstall fails.
# NOTE(review): the reason for removing mccabe and for the exact pins below is
# not stated here (presumably dependency conflicts) -- confirm and document.
RUN pip3 uninstall -y mccabe || true
RUN pip3 install multiprocess==0.70.12 importlib-metadata==4.2.0 dill==0.3.4

# Build a wheel from the cloned sources and install it, resolving
# dependencies from the Tsinghua PyPI mirror.
WORKDIR /home/PaddleSpeech/
RUN python setup.py bdist_wheel
RUN pip install dist/*.whl -i https://pypi.tuna.tsinghua.edu.cn/simple

# Exec form must be a JSON array with double quotes; with single quotes Docker
# falls back to shell form and tries to execute the literal text `['bash']`.
CMD ["bash"]

tests/unit/asr/reverse_pad_list.py

Lines changed: 7 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -65,14 +65,16 @@ def reverse_pad_list_with_sos_eos(r_hyps,
6565
max_len = paddle.max(r_hyps_lens)
6666
index_range = paddle.arange(0, max_len, 1)
6767
seq_len_expand = r_hyps_lens.unsqueeze(1)
68-
seq_mask = seq_len_expand > index_range # (beam, max_len)
68+
seq_mask = seq_len_expand > index_range.astype(
69+
seq_len_expand.dtype) # (beam, max_len)
6970

70-
index = (seq_len_expand - 1) - index_range # (beam, max_len)
71+
index = (seq_len_expand - 1) - index_range.astype(
72+
seq_len_expand.dtype) # (beam, max_len)
7173
# >>> index
7274
# >>> tensor([[ 2, 1, 0],
7375
# >>> [ 2, 1, 0],
7476
# >>> [ 0, -1, -2]])
75-
index = index * seq_mask
77+
index = index * seq_mask.astype(index.dtype)
7678

7779
# >>> index
7880
# >>> tensor([[2, 1, 0],
@@ -103,7 +105,8 @@ def paddle_gather(x, dim, index):
103105
# >>> tensor([[3, 2, 1],
104106
# >>> [4, 8, 9],
105107
# >>> [2, 2, 2]])
106-
r_hyps = paddle.where(seq_mask, r_hyps, eos)
108+
r_hyps = paddle.where(seq_mask, r_hyps,
109+
paddle.to_tensor(eos, dtype=r_hyps.dtype))
107110
# >>> r_hyps
108111
# >>> tensor([[3, 2, 1],
109112
# >>> [4, 8, 9],

0 commit comments

Comments
 (0)