NumPy decoding docs, plus extra tests for llm.encode/decode

!stable-docs

Refs https://discord.com/channels/823971286308356157/1128504153841336370/1151975583237034056
This commit is contained in:
Simon Willison 2023-09-14 14:01:27 -07:00
parent d70c0dba43
commit 356fcb72f6
3 changed files with 29 additions and 0 deletions

View file

@ -20,3 +20,12 @@ def decode(binary):
```
These functions are available as `llm.encode()` and `llm.decode()`.
If you are using [NumPy](https://numpy.org/) you can decode one of these binary values like this:
```python
import numpy as np
numpy_array = np.frombuffer(value, "<f4")
```
The `<f4` format string here ensures NumPy will treat the data as a little-endian sequence of 32-bit floats.

View file

@ -51,6 +51,7 @@ setup(
extras_require={
"test": [
"pytest",
"numpy",
"requests-mock",
"cogapp",
"mypy",

View file

@ -0,0 +1,19 @@
import llm
import pytest
import numpy as np
@pytest.mark.parametrize(
"array",
(
(0.0, 1.0, 1.5),
(3423.0, 222.0, -1234.5),
),
)
def test_roundtrip(array):
encoded = llm.encode(array)
decoded = llm.decode(encoded)
assert decoded == array
# Try with numpy as well
numpy_decoded = np.frombuffer(encoded, "<f4")
assert tuple(numpy_decoded.tolist()) == array