From cd722f653bf0d3d43aea5383b3477253be5bfbd7 Mon Sep 17 00:00:00 2001
From: Simon Willison
Date: Mon, 28 Oct 2024 12:35:51 -0700
Subject: [PATCH] Redact base64 data from _prompt_json, refs #587

---
 llm/default_plugins/openai_models.py | 28 ++++++++++++++++++++++++++--
 1 file changed, 26 insertions(+), 2 deletions(-)

diff --git a/llm/default_plugins/openai_models.py b/llm/default_plugins/openai_models.py
index 651d37f..b7b1c46 100644
--- a/llm/default_plugins/openai_models.py
+++ b/llm/default_plugins/openai_models.py
@@ -346,7 +346,6 @@ class Chat(Model):
                 )
             messages.append({"role": "user", "content": attachment_message})
 
-        response._prompt_json = {"messages": messages}
         kwargs = self.build_kwargs(prompt)
         client = self.get_client()
         if stream:
@@ -372,6 +371,7 @@ class Chat(Model):
             )
             response.response_json = remove_dict_none_values(completion.model_dump())
             yield completion.choices[0].message.content
+        response._prompt_json = redact_data_urls({"messages": messages})
 
     def get_client(self):
         kwargs = {}
@@ -431,7 +431,6 @@ class Completion(Chat):
                 messages.append(prev_response.prompt.prompt)
                 messages.append(prev_response.text())
         messages.append(prompt.prompt)
-        response._prompt_json = {"messages": messages}
         kwargs = self.build_kwargs(prompt)
         client = self.get_client()
         if stream:
@@ -459,6 +458,7 @@ class Completion(Chat):
             )
             response.response_json = remove_dict_none_values(completion.model_dump())
             yield completion.choices[0].text
+        response._prompt_json = redact_data_urls({"messages": messages})
 
 
 def not_nulls(data) -> dict:
@@ -506,3 +506,27 @@ def combine_chunks(chunks: List) -> dict:
             combined[key] = value
 
     return combined
+
+
+def redact_data_urls(input_dict):
+    """
+    Recursively search through the input for any 'image_url' keys and
+    replace base64 'data:' URLs with the placeholder 'data:...'.
+
+    Regular URLs (e.g. https://...) are left untouched. The structure is
+    modified in place and also returned for convenience.
+    """
+    if isinstance(input_dict, dict):
+        for key, value in input_dict.items():
+            if key == "image_url" and isinstance(value, dict) and "url" in value:
+                url = value["url"]
+                # Only inline data: URLs carry the bulky base64 payload;
+                # ordinary URLs are short and should stay visible in logs.
+                if isinstance(url, str) and url.startswith("data:"):
+                    value["url"] = "data:..."
+            else:
+                redact_data_urls(value)
+    elif isinstance(input_dict, list):
+        for item in input_dict:
+            redact_data_urls(item)
+    return input_dict