Skip to content

Commit 66590e2

Browse files
committed
moss-tts: add first-class MOSS-TTS support
1 parent 34818ea commit 66590e2

43 files changed

Lines changed: 4497 additions & 17 deletions

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

convert_hf_to_gguf.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4628,6 +4628,73 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
46284628
yield from super().modify_tensors(data_torch, name, bid)
46294629

46304630

4631+
@ModelBase.register("MossTTSDelayModel", "MossTTSDelayForCausalLM")
class MossTTSDelayModel(Qwen3Model):
    """GGUF converter for MOSS-TTS-Delay checkpoints.

    Builds on the Qwen3 converter: the nested ``language_config`` is flattened
    into the root hparams, MOSS audio metadata is written as extra GGUF keys,
    and the per-codebook embedding / output-head tensors are renamed.
    """

    model_arch = gguf.MODEL_ARCH.MOSS_TTS_DELAY

    def __init__(self, *args, **kwargs):
        """Prepare merged hparams, then defer to the parent constructor.

        When no ``hparams`` keyword is supplied, the config is loaded from the
        first positional argument. A caller-supplied mapping is copied so it
        is never mutated.
        """
        supplied = kwargs.get("hparams")
        if supplied is not None:
            merged = dict(supplied)
        else:
            merged = ModelBase.load_hparams(args[0], self.is_mistral_format)

        backbone_cfg = merged.get("language_config")
        if isinstance(backbone_cfg, dict):
            # Surface the Qwen3 backbone settings (block count, hidden size,
            # attention params, ...) at the root so TextModel can find them.
            # "architectures" and "model_type" are skipped so the top-level
            # MOSS identity is not overwritten by the backbone's.
            merged = dict(merged)
            for key, value in backbone_cfg.items():
                if key in ("architectures", "model_type"):
                    continue
                merged[key] = value

        kwargs["hparams"] = merged
        super().__init__(*args, **kwargs)

    def set_gguf_parameters(self):
        """Write the parent's metadata plus the MOSS-TTS audio parameters.

        Every entry in the table below is mandatory: a missing hparam raises
        ``KeyError``. ``sampling_rate`` is written only when present.
        """
        super().set_gguf_parameters()

        writer = self.gguf_writer
        arch = writer.arch
        llm_keys = gguf.Keys.LLM

        # (GGUF key template, hparams entry) pairs, all stored as uint32.
        required_fields = (
            (llm_keys.N_VQ, "n_vq"),
            (llm_keys.AUDIO_VOCAB_SIZE, "audio_vocab_size"),
            (llm_keys.AUDIO_PAD_CODE, "audio_pad_code"),
            (llm_keys.AUDIO_START_TOKEN_ID, "audio_start_token_id"),
            (llm_keys.AUDIO_END_TOKEN_ID, "audio_end_token_id"),
            (llm_keys.AUDIO_USER_SLOT_TOKEN_ID, "audio_user_slot_token_id"),
            (llm_keys.AUDIO_ASSISTANT_GEN_SLOT_TOKEN_ID, "audio_assistant_gen_slot_token_id"),
            (llm_keys.AUDIO_ASSISTANT_DELAY_SLOT_TOKEN_ID, "audio_assistant_delay_slot_token_id"),
        )
        for key_template, hparam_name in required_fields:
            writer.add_uint32(key_template.format(arch=arch), self.hparams[hparam_name])

        sampling_rate = self.hparams.get("sampling_rate")
        if sampling_rate is not None:
            writer.add_uint32(llm_keys.SAMPLING_RATE.format(arch=arch), sampling_rate)

    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
        """Rename MOSS-specific tensors; delegate everything else to the parent.

        * A leading ``language_model.`` prefix is removed.
        * ``emb_ext.<i>.weight`` becomes the ``<i>``-th audio token embedding.
        * ``lm_heads.0.weight`` becomes the regular output tensor (presumably
          the text head); ``lm_heads.<k>.weight`` with ``k >= 1`` becomes audio
          output head ``k - 1``.
        """
        backbone_prefix = "language_model."
        if name.startswith(backbone_prefix):
            name = name[len(backbone_prefix):]

        embd_match = re.fullmatch(r"emb_ext\.(\d+)\.weight", name)
        if embd_match is not None:
            embd_base = gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.TOKEN_EMBD_AUDIO]
            yield (f"{embd_base}.{int(embd_match.group(1))}.weight", data_torch)
            return

        head_match = re.fullmatch(r"lm_heads\.(\d+)\.weight", name)
        if head_match is not None:
            head_no = int(head_match.group(1))
            if head_no != 0:
                # Audio heads are stored zero-based, hence the shift by one.
                audio_base = gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.OUTPUT_AUDIO]
                yield (f"{audio_base}.{head_no - 1}.weight", data_torch)
            else:
                yield (gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.OUTPUT] + ".weight", data_torch)
            return

        yield from super().modify_tensors(data_torch, name, bid)
4696+
4697+
46314698
@ModelBase.register("Qwen3MoeForCausalLM")
46324699
class Qwen3MoeModel(Qwen2MoeModel):
46334700
model_arch = gguf.MODEL_ARCH.QWEN3MOE

docs/moss-tts-firstclass-e2e.md

Lines changed: 242 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,242 @@
1+
# MOSS-TTS First-Class End-to-End Inference Pipeline
2+
3+
[English](moss-tts-firstclass-e2e.md) | [简体中文](moss-tts-firstclass-e2e_zh.md)
4+
5+
This document describes the **first-class** MOSS-TTS end-to-end inference pipeline in the current `llama.cpp` repository.
6+
7+
This pipeline uses:
8+
9+
- **llama.cpp** and `llama-moss-tts` to run the first-class MOSS-TTS-Delay GGUF model
10+
- **ONNX Runtime** for reference-audio encoding and final waveform decoding
11+
- **Python helper scripts** for prompt construction and end-to-end orchestration
12+
- A local **MOSS-TTS** checkout that provides the prompt builder and ONNX tokenizer Python modules
13+
14+
Unlike the older `moss_tts_delay/llama_cpp` backend in the `MOSS-TTS` repository, this path moves multi-channel inputs, the transformer backbone, multi-head outputs, and delay-pattern decoding into `llama.cpp`. Python is only responsible for preparing inputs and invoking the ONNX audio tokenizer.
15+
16+
## Prerequisites
17+
18+
1. **llama.cpp** built from source with the `llama-moss-tts` target
19+
2. **Python >= 3.10**
20+
3. A local **MOSS-TTS** checkout, provided in any of the following ways:
21+
- available at `../MOSS-TTS` relative to the repository root
22+
- passed through `--moss-tts-dir`
23+
- set through the `MOSS_TTS_DIR` or `MOSS_TTS_ROOT` environment variable
24+
4. Python packages required by the helper scripts:
25+
- `numpy`
26+
- `soundfile`
27+
- `tokenizers`
28+
- `onnxruntime`
29+
30+
## Build
31+
32+
```bash
33+
cd /path/to/llama.cpp
34+
35+
cmake -S . -B build -DCMAKE_BUILD_TYPE=Release -DGGML_CUDA=ON
36+
cmake --build build --target llama-moss-tts -j
37+
```
38+
39+
The resulting binary is:
40+
41+
- `build/bin/llama-moss-tts`
42+
43+
If you would rather build the binary as part of the run, pass `--build` to the e2e script (`tools/tts/moss-tts-firstclass-e2e.py`); it builds `llama-moss-tts` before running.
44+
45+
## Weight Preparation
46+
47+
### Step 1: Prepare the first-class GGUF model
48+
49+
You need a first-class MOSS-TTS-Delay GGUF model that already contains:
50+
51+
- text embedding tables
52+
- 32 audio embedding tables
53+
- Qwen3 backbone weights
54+
- a text output head
55+
- 32 audio output heads
56+
57+
For example:
58+
59+
- `out/moss_delay_firstclass_f16.gguf`
60+
61+
You can generate it directly from the full Hugging Face MOSS-TTS model directory:
62+
63+
```bash
64+
huggingface-cli download OpenMOSS-Team/MOSS-TTS --local-dir /path/to/MOSS-TTS-hf
65+
66+
python convert_hf_to_gguf.py \
67+
/path/to/MOSS-TTS-hf \
68+
--outfile /path/to/moss_delay_firstclass_f16.gguf \
69+
--outtype f16
70+
```
71+
72+
Important:
73+
74+
- The `--model-gguf` file used by this e2e pipeline is a **special first-class MOSS-TTS-Delay GGUF** generated from the full `OpenMOSS-Team/MOSS-TTS` Hugging Face model directory with the command above.
75+
- It is **not** the same thing as a generic GGUF downloaded from `OpenMOSS/MOSS-TTS-GGUF`.
76+
- Do not point this pipeline at a file from `OpenMOSS/MOSS-TTS-GGUF` unless that file was explicitly produced as a first-class MOSS-TTS-Delay GGUF for this `llama.cpp` implementation.
77+
78+
### Step 2: Prepare the tokenizer directory
79+
80+
You need a tokenizer directory containing at least:
81+
82+
- `tokenizer.json`
83+
84+
For example:
85+
86+
- `weights/extracted/qwen3_backbone/`
87+
88+
### Step 3: Prepare the ONNX audio tokenizer
89+
90+
You need both ONNX files:
91+
92+
- `encoder.onnx`
93+
- `decoder.onnx`
94+
95+
For example:
96+
97+
- `weights/MOSS-Audio-Tokenizer-ONNX/encoder.onnx`
98+
- `weights/MOSS-Audio-Tokenizer-ONNX/decoder.onnx`
99+
100+
### Step 4: Make the MOSS-TTS repository visible
101+
102+
The helper scripts import:
103+
104+
- `moss_tts_delay.llama_cpp.processor`
105+
- `moss_audio_tokenizer.onnx`
106+
107+
You can provide the repository path like this:
108+
109+
```bash
110+
export MOSS_TTS_DIR=/path/to/MOSS-TTS
111+
```
112+
113+
or:
114+
115+
```bash
116+
python tools/tts/moss-tts-firstclass-e2e.py --moss-tts-dir /path/to/MOSS-TTS ...
117+
```
118+
119+
## Usage
120+
121+
### CLI
122+
123+
```bash
124+
# Voice cloning: text + reference audio -> wav
125+
python tools/tts/moss-tts-firstclass-e2e.py \
126+
--model-gguf /path/to/moss_delay_firstclass.gguf \
127+
--moss-tts-dir /path/to/MOSS-TTS \
128+
--tokenizer-dir /path/to/tokenizer_dir \
129+
--onnx-encoder /path/to/encoder.onnx \
130+
--onnx-decoder /path/to/decoder.onnx \
131+
--text-file /path/to/text.txt \
132+
--reference-audio /path/to/reference_24k.wav \
133+
--output-wav /path/to/output.wav
134+
135+
# Direct generation without reference audio
136+
python tools/tts/moss-tts-firstclass-e2e.py \
137+
--model-gguf /path/to/moss_delay_firstclass.gguf \
138+
--moss-tts-dir /path/to/MOSS-TTS \
139+
--tokenizer-dir /path/to/tokenizer_dir \
140+
--onnx-encoder /path/to/encoder.onnx \
141+
--onnx-decoder /path/to/decoder.onnx \
142+
--text "Hello, world!" \
143+
--output-wav /path/to/output.wav
144+
145+
# Build llama-moss-tts before running
146+
python tools/tts/moss-tts-firstclass-e2e.py \
147+
--build \
148+
--model-gguf /path/to/moss_delay_firstclass.gguf \
149+
--moss-tts-dir /path/to/MOSS-TTS \
150+
--tokenizer-dir /path/to/tokenizer_dir \
151+
--onnx-encoder /path/to/encoder.onnx \
152+
--onnx-decoder /path/to/decoder.onnx \
153+
--text "Hello!" \
154+
--output-wav /path/to/output.wav
155+
```
156+
157+
## Key Options
158+
159+
| Option | Values | Description |
160+
|------|------|------|
161+
| `--model-gguf` | path | First-class MOSS-TTS GGUF model |
162+
| `--moss-tts-dir` | path | Local `MOSS-TTS` repository root |
163+
| `--tokenizer-dir` | path | Directory containing `tokenizer.json` |
164+
| `--onnx-encoder` | path | Audio tokenizer encoder ONNX |
165+
| `--onnx-decoder` | path | Audio tokenizer decoder ONNX |
166+
| `--text` / `--text-file` | string / path | Input text; choose exactly one of the two |
167+
| `--reference-audio` | path | Optional reference audio; if provided, it must be 24 kHz |
168+
| `--language` | `zh` / `en` / tag | Language tag passed to the prompt builder |
169+
| `--max-new-tokens` | int | Maximum generation steps |
170+
| `--text-temperature` | float | Text-channel sampling temperature, default `1.5` |
171+
| `--audio-temperature` | float | Audio-channel sampling temperature, default `1.7` |
172+
| `--n-gpu-layers` | `-1` / `0` / `N` | GPU offload layers, default `-1` |
173+
| `--audio-decoder-cpu` | flag | Force ONNX waveform decoding on CPU |
174+
| `--cpu-audio-encode` | flag | Force ONNX reference-audio encoding on CPU |
175+
| `--build` | flag | Build `llama-moss-tts` before running |
176+
177+
## Architecture
178+
179+
```text
180+
Input text (+ optional reference wav)
181+
|
182+
v
183+
moss-tts-build-generation-ref.py
184+
|
185+
|- tokenizes text with the Qwen3 tokenizer
186+
|- optionally encodes the reference wav into audio codes with ONNX
187+
|- calls the prompt builder from the local MOSS-TTS repo
188+
v
189+
generation.ref.bin
190+
|
191+
v
192+
llama-moss-tts
193+
|
194+
|- loads the first-class GGUF model
195+
|- performs multi-channel embedding lookup in-graph
196+
|- runs the Qwen3 backbone inside llama.cpp
197+
|- samples multi-head logits
198+
|- performs delay-pattern decoding in C++
199+
v
200+
raw.codes.bin
201+
|
202+
v
203+
moss-tts-audio-decode.py
204+
|
205+
|- decodes raw audio codes into waveform with ONNX
206+
v
207+
wav
208+
```
209+
210+
## Temporary Artifacts
211+
212+
The e2e script creates a temporary directory and removes it automatically after the run.
213+
214+
The following intermediate files are not kept:
215+
216+
- `generation.ref.bin`
217+
- `raw.codes.bin`
218+
219+
The only visible artifact after the run is the output wav you requested.
220+
221+
## Output
222+
223+
At the end of a successful run, the script prints:
224+
225+
- `wav` — output path
226+
- `wav_info` — sample rate, channel count, frame count, and duration
227+
228+
## File Structure
229+
230+
```text
231+
llama.cpp/
232+
├── docs/
233+
│ ├── moss-tts-firstclass-e2e.md
234+
│ └── moss-tts-firstclass-e2e_zh.md
235+
├── tools/tts/
236+
│ ├── moss-tts-firstclass-e2e.py # End-to-end wrapper
237+
│ ├── moss-tts-build-generation-ref.py # Prompt / input builder
238+
│ ├── moss-tts-audio-decode.py # ONNX audio decode helper
239+
│ └── moss-tts.cpp # llama-moss-tts implementation
240+
└── build/bin/
241+
└── llama-moss-tts
242+
```

0 commit comments

Comments
 (0)