Generation

Sampler(transformer, tokenizer, cache_config)

Sampler for a transformer model.

CacheConfig(cache_size, num_layers, ...)

Configuration for the KV cache.


class tunix.Sampler(transformer: Module, tokenizer: Any, cache_config: CacheConfig, image_processor: ImageProcessor | None = None)

Sampler for a transformer model.

property dtype: dtype
init_sample_state(all_input_ids: Array, total_sampling_steps: int, include_logits: bool, forbidden_token_ids: tuple[int, ...] | None, temperature: float, top_p: float | None, top_k: int | None, seed: Array, beam_size: int | None) → _SamplingState

Initializes the sampling state for the given input prompts.

model_def_and_state() → tuple[NodeDef, State]

Returns the transformer graphdef and state.

tokenize(input_string: str) → ndarray | list[int]

Tokenizes the input string.

property transformer: Module

Returns the transformer module used by the sampler.

property transformer_state: State

Returns the transformer state used by the sampler.


class tunix.CacheConfig(cache_size: int, num_layers: int, num_kv_heads: int, head_dim: int)

Configuration for the KV cache.

cache_size: int
head_dim: int
num_kv_heads: int
num_layers: int