diff --git a/src/huggingface_hub/inference/_client.py b/src/huggingface_hub/inference/_client.py index 88f84df194..1e6688c403 100644 --- a/src/huggingface_hub/inference/_client.py +++ b/src/huggingface_hub/inference/_client.py @@ -135,7 +135,7 @@ class InferenceClient: Note: for better compatibility with OpenAI's client, `model` has been aliased as `base_url`. Those 2 arguments are mutually exclusive. If a URL is passed as `model` or `base_url` for chat completion, the `(/v1)/chat/completions` suffix path will be appended to the URL. provider (`str`, *optional*): - Name of the provider to use for inference. Can be `"black-forest-labs"`, `"cerebras"`, `"clarifai"`, `"cohere"`, `"fal-ai"`, `"featherless-ai"`, `"fireworks-ai"`, `"groq"`, `"hf-inference"`, `"hyperbolic"`, `"nebius"`, `"novita"`, `"nscale"`, `"nvidia"`, `"openai"`, `"ovhcloud"`, `"publicai"`, `"replicate"`, `"sambanova"`, `"scaleway"`, `"together"`, `"wavespeed"` or `"zai-org"`. + Name of the provider to use for inference. Can be `"black-forest-labs"`, `"cerebras"`, `"clarifai"`, `"cohere"`, `"fal-ai"`, `"featherless-ai"`, `"fireworks-ai"`, `"groq"`, `"hf-inference"`, `"hyperbolic"`, `"nebius"`, `"novita"`, `"nscale"`, `"nvidia"`, `"openai"`, `"ovhcloud"`, `"publicai"`, `"replicate"`, `"sambanova"`, `"scaleway"`, `"textclf"`, `"together"`, `"wavespeed"` or `"zai-org"`. Defaults to "auto" i.e. the first of the providers available for the model, sorted by the user's order in https://hf.co/settings/inference-providers. If model is a URL or `base_url` is passed, then `provider` is not used. 
token (`str`, *optional*): diff --git a/src/huggingface_hub/inference/_generated/_async_client.py b/src/huggingface_hub/inference/_generated/_async_client.py index 7fde5a9fc2..e7973a1728 100644 --- a/src/huggingface_hub/inference/_generated/_async_client.py +++ b/src/huggingface_hub/inference/_generated/_async_client.py @@ -126,7 +126,7 @@ class AsyncInferenceClient: Note: for better compatibility with OpenAI's client, `model` has been aliased as `base_url`. Those 2 arguments are mutually exclusive. If a URL is passed as `model` or `base_url` for chat completion, the `(/v1)/chat/completions` suffix path will be appended to the URL. provider (`str`, *optional*): - Name of the provider to use for inference. Can be `"black-forest-labs"`, `"cerebras"`, `"clarifai"`, `"cohere"`, `"fal-ai"`, `"featherless-ai"`, `"fireworks-ai"`, `"groq"`, `"hf-inference"`, `"hyperbolic"`, `"nebius"`, `"novita"`, `"nscale"`, `"nvidia"`, `"openai"`, `"ovhcloud"`, `"publicai"`, `"replicate"`, `"sambanova"`, `"scaleway"`, `"together"`, `"wavespeed"` or `"zai-org"`. + Name of the provider to use for inference. Can be `"black-forest-labs"`, `"cerebras"`, `"clarifai"`, `"cohere"`, `"fal-ai"`, `"featherless-ai"`, `"fireworks-ai"`, `"groq"`, `"hf-inference"`, `"hyperbolic"`, `"nebius"`, `"novita"`, `"nscale"`, `"nvidia"`, `"openai"`, `"ovhcloud"`, `"publicai"`, `"replicate"`, `"sambanova"`, `"scaleway"`, `"textclf"`, `"together"`, `"wavespeed"` or `"zai-org"`. Defaults to "auto" i.e. the first of the providers available for the model, sorted by the user's order in https://hf.co/settings/inference-providers. If model is a URL or `base_url` is passed, then `provider` is not used. 
token (`str`, *optional*): diff --git a/src/huggingface_hub/inference/_providers/__init__.py b/src/huggingface_hub/inference/_providers/__init__.py index 3e7788fc07..50b29af17c 100644 --- a/src/huggingface_hub/inference/_providers/__init__.py +++ b/src/huggingface_hub/inference/_providers/__init__.py @@ -50,6 +50,7 @@ ) from .sambanova import SambanovaConversationalTask, SambanovaFeatureExtractionTask from .scaleway import ScalewayConversationalTask, ScalewayFeatureExtractionTask +from .textclf import TextCLFConversationalTask, TextCLFTextGenerationTask from .together import TogetherConversationalTask, TogetherTextGenerationTask, TogetherTextToImageTask from .wavespeed import ( WavespeedAIImageToImageTask, @@ -84,6 +85,7 @@ "replicate", "sambanova", "scaleway", + "textclf", "together", "wavespeed", "zai-org", @@ -200,6 +202,10 @@ "conversational": ScalewayConversationalTask(), "feature-extraction": ScalewayFeatureExtractionTask(), }, + "textclf": { + "text-generation": TextCLFTextGenerationTask(), + "conversational": TextCLFConversationalTask(), + }, "together": { "text-to-image": TogetherTextToImageTask(), "conversational": TogetherConversationalTask(), diff --git a/src/huggingface_hub/inference/_providers/_common.py b/src/huggingface_hub/inference/_providers/_common.py index 872cfc89f6..237499cbe5 100644 --- a/src/huggingface_hub/inference/_providers/_common.py +++ b/src/huggingface_hub/inference/_providers/_common.py @@ -37,6 +37,7 @@ "replicate": {}, "sambanova": {}, "scaleway": {}, + "textclf": {}, "together": {}, "wavespeed": {}, "zai-org": {}, diff --git a/src/huggingface_hub/inference/_providers/textclf.py b/src/huggingface_hub/inference/_providers/textclf.py new file mode 100644 index 0000000000..6acb96ca1b --- /dev/null +++ b/src/huggingface_hub/inference/_providers/textclf.py @@ -0,0 +1,30 @@ +from typing import Any, Optional, Union + +from huggingface_hub.inference._common import RequestParameters, _as_dict +from 
huggingface_hub.inference._providers._common import ( + BaseConversationalTask, + BaseTextGenerationTask, +) + +_PROVIDER = "textclf" +_BASE_URL = "https://api.textclf.com" + + +class TextCLFTextGenerationTask(BaseTextGenerationTask): + def __init__(self): + super().__init__(provider=_PROVIDER, base_url=_BASE_URL) + + def get_response(self, response: Union[bytes, dict], request_params: Optional[RequestParameters] = None) -> Any: + output = _as_dict(response)["choices"][0] + return { + "generated_text": output["text"], + "details": { + "finish_reason": output.get("finish_reason"), + "seed": output.get("seed"), + }, + } + + +class TextCLFConversationalTask(BaseConversationalTask): + def __init__(self): + super().__init__(provider=_PROVIDER, base_url=_BASE_URL) diff --git a/tests/test_inference_client.py b/tests/test_inference_client.py index 4bbf5895ab..ae8210bda9 100644 --- a/tests/test_inference_client.py +++ b/tests/test_inference_client.py @@ -126,6 +126,10 @@ "sambanova": { "conversational": "meta-llama/Llama-3.1-8B-Instruct", }, + "textclf": { + "text-generation": "meta-llama/Llama-3.1-8B-Instruct", + "conversational": "meta-llama/Llama-3.1-8B-Instruct", + }, } CHAT_COMPLETION_MODEL = "HuggingFaceH4/zephyr-7b-beta" diff --git a/tests/test_inference_providers.py b/tests/test_inference_providers.py index 80cb589eed..3d51e511b4 100644 --- a/tests/test_inference_providers.py +++ b/tests/test_inference_providers.py @@ -57,6 +57,7 @@ ) from huggingface_hub.inference._providers.sambanova import SambanovaConversationalTask, SambanovaFeatureExtractionTask from huggingface_hub.inference._providers.scaleway import ScalewayConversationalTask, ScalewayFeatureExtractionTask +from huggingface_hub.inference._providers.textclf import TextCLFConversationalTask, TextCLFTextGenerationTask from huggingface_hub.inference._providers.together import TogetherTextToImageTask from huggingface_hub.inference._providers.wavespeed import (
WavespeedAIImageToImageTask, @@ -1728,6 +1729,16 @@ def test_prepare_url_feature_extraction(self): == "https://router.huggingface.co/sambanova/v1/embeddings" ) +class TestTextCLFProvider: + def test_prepare_url_text_generation(self): + helper = TextCLFTextGenerationTask() + url = helper._prepare_url("textclf_token", "username/repo_name") + assert url == "https://api.textclf.com/v1/completions" + + def test_prepare_url_conversational(self): + helper = TextCLFConversationalTask() + url = helper._prepare_url("textclf_token", "username/repo_name") + assert url == "https://api.textclf.com/v1/chat/completions" class TestTogetherProvider: def test_prepare_route_text_to_image(self):