Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions dell_ai/cli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
AuthenticationError,
ResourceNotFoundError,
ValidationError,
GatedModelError,
)
from dell_ai.cli.utils import (
get_client,
Expand Down Expand Up @@ -160,6 +161,33 @@ def models_show(model_id: str) -> None:
print_error(f"Failed to get model information: {str(e)}")


@models_app.command("check-access")
def models_check_access(model_id: str) -> None:
    """
    Check if you have access to a specific model.

    This is useful for determining if you can access a gated model
    before attempting to use it. If you don't have access, this command
    will provide a URL where you can request access.

    Args:
        model_id: The model ID in the format "organization/model_name"
    """
    try:
        client = get_client()
        # The client returns a boolean; a falsy result means access could not
        # be confirmed, so it must not be reported as success.
        has_access = client.check_model_access(model_id)
    except GatedModelError as e:
        # The exception message already contains the URL to request access.
        typer.echo(f"❌ {str(e)}")
        raise typer.Exit(code=1)
    except ValidationError as e:
        print_error(f"Invalid model ID: {str(e)}")
    except AuthenticationError as e:
        print_error(f"Authentication error: {str(e)}")
    except Exception as e:
        print_error(f"Error checking model access: {str(e)}")
    else:
        # Handled in `else` so the exit raised below is not intercepted by the
        # broad `except Exception` handler above.
        if has_access:
            typer.echo(f"✅ You have access to model: {model_id}")
        else:
            typer.echo(f"❌ Could not verify access to model: {model_id}")
            raise typer.Exit(code=1)


@platforms_app.command("list")
def platforms_list() -> None:
"""
Expand Down
30 changes: 30 additions & 0 deletions dell_ai/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,36 @@ def get_platform(self, platform_id: str) -> "Platform":

return platforms.get_platform(self, platform_id)

def check_model_access(self, model_id: str) -> bool:
    """
    Report whether the authenticated user may access a model.

    Call this before using a model to confirm that the current token grants
    permission, which matters most for gated models that require an explicit
    access grant.

    Args:
        model_id: The model ID in the format "organization/model_name"

    Returns:
        True if the user has access, False if an unexpected error occurs

    Raises:
        ValidationError: If the model_id format is invalid
        GatedModelError: If the model is gated and the user doesn't have access
        AuthenticationError: If authentication fails or no token is available
    """
    from dell_ai import models

    # Happy path: a token is present, so delegate to the models module.
    if self.token:
        return models.check_model_access(self, model_id)

    # No token: the exception class is only imported on this failure path.
    from dell_ai.exceptions import AuthenticationError

    raise AuthenticationError(
        "No authentication token available. Please login first."
    )

def get_deployment_snippet(
self,
model_id: str,
Expand Down
24 changes: 24 additions & 0 deletions dell_ai/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,30 @@ def __init__(self, resource_type, resource_id, original_error=None):
super().__init__(message, original_error)


class GatedModelError(DellAIError):
    """Signals that a gated model was requested without the required permission.

    Raised when a user tries to reach a model whose access is gated and the
    user has not been granted access to it.
    """

    def __init__(self, model_id, access_url=None, original_error=None):
        """Build the error and its user-facing message.

        Args:
            model_id: The ID of the gated model.
            access_url: URL where the user can request access to the model;
                when omitted (or empty) a Hugging Face URL built from
                ``model_id`` is used instead.
            original_error: The original exception that caused this error, if any.
        """
        self.model_id = model_id

        # Fall back to a URL derived from the model ID when none is supplied.
        if access_url:
            self.access_url = access_url
        else:
            self.access_url = f"https://huggingface.co/{model_id}"

        super().__init__(
            f"Access to model '{model_id}' is restricted and you do not have permission to access it."
            f"\nVisit {self.access_url} to request access.",
            original_error,
        )


class ValidationError(DellAIError):
"""Raised when input validation fails.

Expand Down
39 changes: 38 additions & 1 deletion dell_ai/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@
from typing import Dict, List, TYPE_CHECKING

from pydantic import BaseModel, Field
from huggingface_hub import model_info
from huggingface_hub.utils import GatedRepoError as HFGatedRepoError

from dell_ai import constants
from dell_ai.exceptions import ResourceNotFoundError, ValidationError
from dell_ai.exceptions import ResourceNotFoundError, ValidationError, GatedModelError

if TYPE_CHECKING:
from dell_ai.client import DellAIClient
Expand Down Expand Up @@ -107,3 +109,38 @@ def get_model(client: "DellAIClient", model_id: str) -> Model:
except ResourceNotFoundError:
# Reraise with more specific information
raise ResourceNotFoundError("model", model_id)


def check_model_access(client: "DellAIClient", model_id: str) -> bool:
    """
    Check if the user has access to the specified model.

    The check is best-effort: only a gated-repository failure is treated as a
    definite "no access" and raised. Any other failure while querying the
    Hugging Face Hub is logged at debug level and reported as ``False``.

    Args:
        client: The Dell AI client (its ``token`` is used for the Hub request)
        model_id: The model ID in the format "organization/model_name"

    Returns:
        True if the model information could be retrieved with the client's
        token, False if the lookup failed for a reason other than gating

    Raises:
        ValidationError: If the model_id format is invalid
        GatedModelError: If the model is gated and the user doesn't have access
    """
    # Local import: only this function needs logging.
    import logging

    # Validate model_id format
    if "/" not in model_id:
        raise ValidationError(
            "Invalid model ID format. Expected format: 'organization/model_name'",
            parameter="model_id",
        )

    try:
        # Use the Hugging Face Hub library to check if the model is accessible.
        # This is expected to raise GatedRepoError if the model is gated and
        # the user doesn't have access.
        model_info(model_id, token=client.token)
    except HFGatedRepoError as e:
        # The model is gated and the user doesn't have access. Chain the
        # original error explicitly so the traceback shows the real cause.
        raise GatedModelError(model_id, original_error=e) from e
    except Exception as e:
        # For any other error, we assume it's not specifically an access issue.
        # Keep the best-effort False result, but record why the check failed
        # instead of discarding the error silently.
        logging.getLogger(__name__).debug(
            "Could not verify access to model %s: %s", model_id, e
        )
        return False
    return True
17 changes: 13 additions & 4 deletions dell_ai/snippets.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from dell_ai import constants
from dell_ai.client import DellAIClient
from dell_ai.exceptions import ValidationError, ResourceNotFoundError
from dell_ai.exceptions import ValidationError, ResourceNotFoundError, GatedModelError
from dell_ai import models


Expand Down Expand Up @@ -198,21 +198,30 @@ def get_deployment_snippet(
Raises:
ValidationError: If any of the input parameters are invalid
ResourceNotFoundError: If the model, platform, or configuration is not found
GatedModelError: If the model is gated and the user doesn't have access
"""
# Step 1: Validate basic request parameters
_validate_request_schema(model_id, platform_id, engine, num_gpus, num_replicas)

# Step 2: Parse and validate model ID format
creator_name, model_name = _validate_model_id_format(model_id)

# Step 3: Validate model and platform compatibility if the model exists
# Step 3: Check if the user has access to the model
# This will raise GatedModelError if the model is gated and the user doesn't have access
try:
models.check_model_access(client, model_id)
except GatedModelError:
# Re-raise the exception - we let it bubble up
raise

# Step 4: Validate model and platform compatibility if the model exists
try:
_validate_model_platform_compatibility(client, model_id, platform_id, num_gpus)
except ResourceNotFoundError:
# We'll handle this during the API request
pass

# Step 4: Build API path and query parameters
# Step 5: Build API path and query parameters
path = f"{constants.SNIPPETS_ENDPOINT}/models/{creator_name}/{model_name}/deploy"
params = {
"sku": platform_id, # API still expects "sku" as the parameter name
Expand All @@ -221,7 +230,7 @@ def get_deployment_snippet(
"gpus": num_gpus,
}

# Step 5: Make API request and handle errors
# Step 6: Make API request and handle errors
try:
response = client._make_request("GET", path, params=params)
return SnippetResponse(snippet=response.get("snippet", "")).snippet
Expand Down