schema_dsl(..., multi=True) parameter, refs #790

This commit is contained in:
Simon Willison 2025-02-27 10:28:42 -08:00
parent 8d32b71ef1
commit eb2b243fdf
4 changed files with 45 additions and 7 deletions

View file

@ -126,7 +126,14 @@ print(model.prompt(
schema=llm.schema_dsl("name, age int, bio")
))
```
Pass `multi=True` to generate a schema for a response containing multiple items that each match that specification:
```python
print(model.prompt(
"Describe 3 nice dogs with surprising names",
schema=llm.schema_dsl("name, age int, bio", multi=True)
))
```
(python-api-model-options)=
### Model options

View file

@ -46,6 +46,7 @@ from .utils import (
output_rows_as_json,
resolve_schema_input,
schema_summary,
multi_schema,
)
import base64
import httpx
@ -307,11 +308,7 @@ def prompt(
if schema_multi:
# Convert that schema into multiple "items" of the same schema
schema = {
"type": "object",
"properties": {"items": {"type": "array", "items": schema}},
"required": ["items"],
}
schema = multi_schema(schema)
model_aliases = get_model_aliases()

View file

@ -303,13 +303,14 @@ def schema_summary(schema: dict) -> str:
return ""
def schema_dsl(schema_dsl: str) -> Dict[str, Any]:
def schema_dsl(schema_dsl: str, multi: bool = False) -> Dict[str, Any]:
"""
Build a JSON schema from a concise schema string.
Args:
schema_dsl: A string representing a schema in the concise format.
Can be comma-separated or newline-separated.
multi: Boolean, return a schema for an "items" array of these
Returns:
A dictionary representing the JSON schema.
@ -364,4 +365,16 @@ def schema_dsl(schema_dsl: str) -> Dict[str, Any]:
# Add field to required list
json_schema["required"].append(field_name)
return json_schema
if multi:
return multi_schema(json_schema)
else:
return json_schema
def multi_schema(schema: dict) -> dict:
    """
    Wrap a JSON schema so that it describes a list of matching items.

    Args:
        schema: The JSON schema each individual item must satisfy.

    Returns:
        An object schema with a single required "items" property: an
        array whose elements are validated against the given schema.
    """
    # The given schema is embedded as-is (not copied) as the element schema.
    item_list = {"type": "array", "items": schema}
    return {
        "type": "object",
        "properties": {"items": item_list},
        "required": ["items"],
    }

View file

@ -212,3 +212,24 @@ def test_extract_fenced_code_block(input, last, expected):
def test_schema_dsl(schema, expected):
result = schema_dsl(schema)
assert result == expected
def test_schema_dsl_multi():
    """multi=True nests the parsed schema inside a required "items" array."""
    # Schema that the DSL string produces for one item on its own.
    single_item = {
        "type": "object",
        "properties": {
            "name": {"type": "string"},
            "age": {"type": "integer", "description": "The age"},
        },
        "required": ["name", "age"],
    }
    # With multi=True that schema becomes the element type of "items".
    wrapped = {
        "type": "object",
        "properties": {
            "items": {
                "type": "array",
                "items": single_item,
            }
        },
        "required": ["items"],
    }
    assert schema_dsl("name, age int: The age", multi=True) == wrapped