Spaces: VenkateshRoshan (status: Runtime error)

deployment file update

Commit: b1d9c58
Parent(s): f0c8e4c
Committed by: VenkateshRoshan
Files changed:
- .github/workflows/deploy.yml   +44 -19
- app.py                         +61 -39
- dockerfile                     +4 -4
- local_app.py                   +150 -0
- requirements.txt               +0 -1
- src/deploy_sagemaker.py        +41 -76
- src/local_deploy_sagemaker.py  +89 -0
- src/push_to_s3.py              +52 -0
.github/workflows/deploy.yml
CHANGED
@@ -1,4 +1,4 @@
-name: Deploy to
+name: Deploy to SageMaker
 
 on:
   push:
@@ -6,21 +6,46 @@ on:
       - main
 
 jobs:
+  deploy:
+    runs-on: ubuntu-latest
+
+    steps:
+      - name: Checkout code
+        uses: actions/checkout@v2
+
+      - name: Setup Python
+        uses: actions/setup-python@v3
+        with:
+          python-version: '3.10'
+
+      - name: Install dependencies
+        run: |
+          python -m pip install --upgrade pip
+          pip install --no-cache-dir --upgrade pip
+          pip install --no-cache-dir torch==2.1.0 torchvision==0.16.0 torchaudio==2.1.0 --index-url https://download.pytorch.org/whl/cu121
+          pip install --no-cache-dir -r requirements.txt
+
+      - name: Login to AWS
+        uses: aws-actions/configure-aws-credentials@v2
+        with:
+          aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }}
+          aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
+          aws-region: ${{ secrets.AWS_REGION }}
+
+      - name: Login to Amazon ECR
+        id: login-ecr
+        uses: aws-actions/amazon-ecr-login@v1
+
+      - name: Build and push Docker image
+        run: |
+          docker build -t ${{ secrets.ACCOUNT_ID }}.dkr.ecr.${{ secrets.AWS_REGION }}.amazonaws.com/customer_support_bot:latest .
+          docker push ${{ secrets.ACCOUNT_ID }}.dkr.ecr.${{ secrets.AWS_REGION }}.amazonaws.com/customer_support_bot:latest
+
+      - name: Deploy model to SageMaker
+        run: |
+          python deploy_sagemaker.py \
+            --account_id ${{ secrets.ACCOUNT_ID }} \
+            --region ${{ secrets.AWS_REGION }} \
+            --role_arn ${{ secrets.SAGEMAKER_ROLE_ARN }} \
+            --ecr_repo_name "customer_support_bot" \
+            --endpoint_name "customer-support-chatbot"
app.py
CHANGED
@@ -1,54 +1,75 @@
-import json
+import json
 import psutil
 import torch
-from transformers import AutoTokenizer
+from transformers import AutoTokenizer, AutoModelForCausalLM
 import gradio as gr
 import os
+import tarfile
 from typing import List, Tuple
 
 class CustomerSupportBot:
+    def __init__(self, model_path="models/customer_support_gpt"):
+        """
+        Initialize the customer support bot with the fine-tuned model.
+
+        Args:
+            model_path (str): Path to the saved model and tokenizer
+        """
         self.process = psutil.Process(os.getpid())
-        # self.tokenizer = AutoTokenizer.from_pretrained("gpt2") # Use the tokenizer appropriate to your model
-        self.endpoint_name = endpoint_name
-        self.sagemaker_runtime = boto3.client('runtime.sagemaker')
+        self.model_path = model_path
+        self.model_file_path = os.path.join(self.model_path, "model.tar.gz")
 
+        # Download and load the model
+        self.download_and_load_model()
+
+    def download_and_load_model(self):
+        # Check if the model directory exists
+        if not os.path.exists(self.model_path):
+            os.makedirs(self.model_path)
+
+        # Download model.tar.gz from S3 if not already downloaded
+        if not os.path.exists(self.model_file_path):
+            print("Downloading model from S3...")
+            self.s3.download_file(self.bucket_name, self.model_key, self.model_file_path)
+            print("Download complete. Extracting model files...")
 
-            print(f'JSON Payload: {json_payload}')
-            # Call the SageMaker endpoint for inference
-            response = self.sagemaker_runtime.invoke_endpoint(
-                EndpointName=self.endpoint_name,
-                ContentType='application/json',
-                Body=json_payload # Send the JSON string here
-            )
-            print(f'Response: {response}')
+        # Extract the model files
+        with tarfile.open(self.model_file_path, "r:gz") as tar:
+            tar.extractall(self.model_path)
 
+        # Load the model and tokenizer from extracted files
+        self.tokenizer = AutoTokenizer.from_pretrained(self.model_path)
+        self.model = AutoModelForCausalLM.from_pretrained(self.model_path)
+        print("Model and tokenizer loaded successfully.")
 
+        # Move model to GPU if available
+        self.device = "cuda" if torch.cuda.is_available() else "cpu"
+        self.model = self.model.to(self.device)
 
+    def generate_response(self, message: str, max_length=100, temperature=0.7) -> str:
+        try:
+            input_text = f"Instruction: {message}\nResponse:"
+
+            # Tokenize input text
+            inputs = self.tokenizer(input_text, return_tensors="pt").to(self.device)
+
+            # Generate response using the model
+            with torch.no_grad():
+                outputs = self.model.generate(
+                    **inputs,
+                    max_length=max_length,
+                    temperature=temperature,
+                    num_return_sequences=1,
+                    pad_token_id=self.tokenizer.pad_token_id,
+                    eos_token_id=self.tokenizer.eos_token_id,
+                    do_sample=True,
+                    top_p=0.95,
+                    top_k=50
+                )
+
+            # Decode and format the response
+            response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
+            response = response.split("Response:")[-1].strip()
             return response
         except Exception as e:
             return f"An error occurred: {str(e)}"
@@ -60,8 +81,9 @@ class CustomerSupportBot:
         }
         return usage
 
+
 def create_chat_interface():
-    bot = CustomerSupportBot()
+    bot = CustomerSupportBot(model_path="/app/models")
 
     def predict(message: str, history: List[Tuple[str, str]]) -> Tuple[str, List[Tuple[str, str]]]:
         if not message:
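Note: download_and_load_model() above references self.s3, self.bucket_name, and self.model_key, which are not assigned anywhere in the app.py diff shown here. A minimal sketch of the initialization the method appears to assume; the bucket name and key are taken from src/push_to_s3.py in this same commit and are assumptions, not confirmed for app.py:

import boto3

# Sketch of what CustomerSupportBot.__init__() would also need to set before
# calling download_and_load_model(); values are assumptions based on src/push_to_s3.py.
def init_s3_attributes(bot):
    bot.s3 = boto3.client("s3")               # client used by self.s3.download_file(...)
    bot.bucket_name = "customer-support-gpt"  # bucket written by src/push_to_s3.py
    bot.model_key = "models/model.tar.gz"     # key written by src/push_to_s3.py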
dockerfile
CHANGED
@@ -17,11 +17,11 @@ RUN pip install --no-cache-dir --upgrade pip
 RUN pip install --no-cache-dir torch==2.1.0 torchvision==0.16.0 torchaudio==2.1.0 --index-url https://download.pytorch.org/whl/cu121
 RUN pip install --no-cache-dir -r requirements.txt
 
-# Copy .env file to the working directory
-COPY .env /app/.env
+# # Copy .env file to the working directory
+# COPY .env /app/.env
 
-# Set environment variables from .env file
-ENV $(cat /app/.env | xargs)
+# # Set environment variables from .env file
+# ENV $(cat /app/.env | xargs)
 
 # Expose port 7860
 EXPOSE 7860
local_app.py
ADDED
@@ -0,0 +1,150 @@
import json  # Add this import
import psutil
import torch
import boto3
from transformers import AutoTokenizer
import gradio as gr
import os
from typing import List, Tuple

class CustomerSupportBot:
    def __init__(self, endpoint_name="customer-support-gpt-2024-11-10-00-30-03-555"):
        self.process = psutil.Process(os.getpid())
        model_name = "EleutherAI/gpt-neo-125M"
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        # self.tokenizer = AutoTokenizer.from_pretrained("gpt2") # Use the tokenizer appropriate to your model
        self.endpoint_name = endpoint_name
        self.sagemaker_runtime = boto3.client('runtime.sagemaker')

    def generate_response(self, message: str) -> str:
        try:
            input_text = f"Instruction: {message}\nResponse:"

            # Prepare payload for SageMaker endpoint
            payload = {
                # "inputs": inputs['input_ids'].tolist()[0],
                'inputs': input_text,
                # You can include other parameters if needed (e.g., attention_mask)
            }
            print(f'Payload: {payload}')
            # Convert the payload to a JSON string before sending
            json_payload = json.dumps(payload)  # Use json.dumps() to serialize the payload
            print(f'JSON Payload: {json_payload}')
            # Call the SageMaker endpoint for inference
            response = self.sagemaker_runtime.invoke_endpoint(
                EndpointName=self.endpoint_name,
                ContentType='application/json',
                Body=json_payload  # Send the JSON string here
            )
            print(f'Response: {response}')

            # Process the response
            result = response['Body'].read().decode('utf-8')
            parsed_result = json.loads(result)

            # Extract the generated text from the first element in the list
            generated_text = parsed_result[0]['generated_text']

            # Split the string to get the response part after 'Response:'
            response = generated_text.split('Response:')[1].strip()

            # return the extracted response
            return response
        except Exception as e:
            return f"An error occurred: {str(e)}"

    def monitor_resources(self) -> dict:
        usage = {
            "CPU (%)": self.process.cpu_percent(interval=1),
            "RAM (GB)": self.process.memory_info().rss / (1024 ** 3)
        }
        return usage

def create_chat_interface():
    bot = CustomerSupportBot()

    def predict(message: str, history: List[Tuple[str, str]]) -> Tuple[str, List[Tuple[str, str]]]:
        if not message:
            return "", history

        bot_response = bot.generate_response(message)

        # Log resource usage
        usage = bot.monitor_resources()
        print("Resource Usage:", usage)

        history.append((message, bot_response))
        return "", history

    # Create the Gradio interface with custom CSS
    with gr.Blocks(css="""
        .message-box {
            margin-bottom: 10px;
        }
        .button-row {
            display: flex;
            gap: 10px;
            margin-top: 10px;
        }
    """) as interface:
        gr.Markdown("# Customer Support Chatbot")
        gr.Markdown("Welcome! How can I assist you today?")

        chatbot = gr.Chatbot(
            label="Chat History",
            height=500,
            elem_classes="message-box"
        )

        with gr.Row():
            msg = gr.Textbox(
                label="Your Message",
                placeholder="Type your message here...",
                lines=2,
                elem_classes="message-box"
            )

        with gr.Row(elem_classes="button-row"):
            submit = gr.Button("Send Message", variant="primary")
            clear = gr.ClearButton([msg, chatbot], value="Clear Chat")

        # Add example queries in a separate row
        with gr.Row():
            gr.Examples(
                examples=[
                    "How do I reset my password?",
                    "What are your shipping policies?",
                    "I want to return a product.",
                    "How can I track my order?",
                    "What payment methods do you accept?"
                ],
                inputs=msg,
                label="Example Questions"
            )

        # Set up event handlers
        submit_click = submit.click(
            predict,
            inputs=[msg, chatbot],
            outputs=[msg, chatbot]
        )

        msg.submit(
            predict,
            inputs=[msg, chatbot],
            outputs=[msg, chatbot]
        )

        # Add keyboard shortcut for submit
        msg.change(lambda x: gr.update(interactive=bool(x.strip())), inputs=[msg], outputs=[submit])

    return interface

if __name__ == "__main__":
    demo = create_chat_interface()
    demo.launch(
        share=True,
        server_name="0.0.0.0",  # Makes the server accessible from other machines
        server_port=7860,       # Specify the port
        debug=True
    )
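For reference, the endpoint contract used by local_app.py can be exercised directly with boto3; a minimal sketch using the same payload format and response parsing as generate_response() above (the hard-coded endpoint name is the one from local_app.py and may no longer exist):

import json
import boto3

runtime = boto3.client("runtime.sagemaker")

payload = {"inputs": "Instruction: How do I reset my password?\nResponse:"}
response = runtime.invoke_endpoint(
    EndpointName="customer-support-gpt-2024-11-10-00-30-03-555",  # endpoint name hard-coded in local_app.py
    ContentType="application/json",
    Body=json.dumps(payload),
)
result = json.loads(response["Body"].read().decode("utf-8"))
print(result[0]["generated_text"].split("Response:")[1].strip())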
requirements.txt
CHANGED
@@ -1,5 +1,4 @@
 transformers==4.37
-torch
 mlflow
 boto3
 pytest
src/deploy_sagemaker.py
CHANGED
@@ -1,91 +1,56 @@
 import boto3
-from pathlib import Path
-import sagemaker
-from sagemaker.huggingface import HuggingFaceModel
-import transformers
-import torch
 import logging
+import sagemaker
+from sagemaker.model import Model
+import argparse
 import os
 
 # Set up logging
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
 
-    with tarfile.open(tar_path, "w:gz") as tar:
-        for file_path in model_path.glob("*"):
-            if file_path.is_file():
-                logger.info(f"Adding {file_path} to tar archive")
-                tar.add(file_path, arcname=file_path.name)
-
-    return tar_path
-
-try:
-    # Initialize s3 client
-    s3 = boto3.client("s3")
-    bucket_name = 'customer-support-gpt'
-
-    # Create and upload tar.gz
-    tar_path = create_model_tar()
-    s3_key = "models/model.tar.gz" # Changed path
-    logger.info(f"Uploading model.tar.gz to s3://{bucket_name}/{s3_key}")
-    s3.upload_file(tar_path, bucket_name, s3_key)
+def deploy_app(acc_id, region_name, role_arn, ecr_repo_name, endpoint_name="customer-support-chatbot"):
+    """
+    Deploys a Gradio app as a SageMaker endpoint using an ECR image.
+
+    Args:
+        acc_id (str): AWS account ID
+        region_name (str): AWS region name
+        role_arn (str): IAM role ARN for SageMaker
+        ecr_repo_name (str): ECR repository name
+        endpoint_name (str): SageMaker endpoint name (default: "customer-support-chatbot")
+    """
     # Initialize SageMaker session
     sagemaker_session = sagemaker.Session()
-    role = 'arn:aws:iam::841162707028:role/service-role/AmazonSageMaker-ExecutionRole-20241109T160615'
-
-    # Verify IAM role
-    iam = boto3.client('iam')
-    try:
-        iam.get_role(RoleName=role.split('/')[-1])
-        logger.info(f"Successfully verified IAM role: {role}")
-    except iam.exceptions.NoSuchEntityException:
-        logger.error(f"IAM role not found: {role}")
-        raise
 
+    # Define the image URI in ECR
+    ecr_image = f"{acc_id}.dkr.ecr.{region_name}.amazonaws.com/{ecr_repo_name}:latest"
+
+    # Define model
+    model = Model(
+        image_uri=ecr_image,
+        role=role_arn,
+        sagemaker_session=sagemaker_session
+    )
 
-        name="customer-support-gpt"
-    )
-
-    logger.info("Starting model deployment...")
-    predictor = huggingface_model.deploy(
-        initial_instance_count=1,
-        instance_type="ml.m5.xlarge",
-        wait=True
-    )
-    logger.info("Model deployed successfully!")
-
-    except Exception as e:
-        logger.error(f"Error during model deployment: {str(e)}")
-        raise
+    # Deploy model as a SageMaker endpoint
+    logger.info(f"Starting deployment of Gradio app to SageMaker endpoint {endpoint_name}...")
+    predictor = model.deploy(
+        initial_instance_count=1,
+        instance_type="ml.m5.xlarge",
+        endpoint_name=endpoint_name
+    )
+    logger.info(f"Gradio app deployed successfully to endpoint: {endpoint_name}")
 
-    if os.path.exists(tar_path):
-        os.remove(tar_path)
+if __name__ == "__main__":
+    # Parse arguments from CLI
+    parser = argparse.ArgumentParser(description="Deploy Gradio app to SageMaker")
+    parser.add_argument("--account_id", type=str, required=True, help="AWS Account ID")
+    parser.add_argument("--region", type=str, required=True, help="AWS Region")
+    parser.add_argument("--role_arn", type=str, required=True, help="IAM Role ARN for SageMaker")
+    parser.add_argument("--ecr_repo_name", type=str, required=True, help="ECR Repository name")
+    parser.add_argument("--endpoint_name", type=str, default="customer-support-chatbot", help="SageMaker Endpoint Name")
+    args = parser.parse_args()
 
+    # Deploy the Gradio app to SageMaker
+    deploy_app(args.account_id, args.region, args.role_arn, args.ecr_repo_name, args.endpoint_name)
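For reference, a usage sketch of the new deploy_app entry point, equivalent to the CLI call in the workflow above; the account ID, region, and role ARN below are placeholders, not values from this commit:

# Sketch only: assumes deploy_sagemaker.py is importable from the working directory
# (the file lives under src/ in this repo).
from deploy_sagemaker import deploy_app

deploy_app(
    acc_id="123456789012",                                             # placeholder account ID
    region_name="us-east-1",                                           # placeholder region
    role_arn="arn:aws:iam::123456789012:role/SageMakerExecutionRole",  # placeholder role ARN
    ecr_repo_name="customer_support_bot",
    endpoint_name="customer-support-chatbot",
)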
src/local_deploy_sagemaker.py
ADDED
@@ -0,0 +1,89 @@
import boto3
from pathlib import Path
import sagemaker
from sagemaker.huggingface import HuggingFaceModel
import logging
import tarfile
import os

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

def create_model_tar():
    model_path = Path("models/customer_support_gpt")
    tar_path = "model.tar.gz"

    with tarfile.open(tar_path, "w:gz") as tar:
        for file_path in model_path.glob("*"):
            if file_path.is_file():
                logger.info(f"Adding {file_path} to tar archive")
                tar.add(file_path, arcname=file_path.name)

    return tar_path

try:
    # Initialize s3 client
    s3 = boto3.client("s3")
    bucket_name = 'customer-support-gpt'

    # Create and upload tar.gz
    tar_path = create_model_tar()
    s3_key = "models/model.tar.gz"  # Changed path
    logger.info(f"Uploading model.tar.gz to s3://{bucket_name}/{s3_key}")
    s3.upload_file(tar_path, bucket_name, s3_key)

    # Initialize SageMaker session
    sagemaker_session = sagemaker.Session()
    role = 'arn:aws:iam::841162707028:role/service-role/AmazonSageMaker-ExecutionRole-20241109T160615'

    # Verify IAM role
    iam = boto3.client('iam')
    try:
        iam.get_role(RoleName=role.split('/')[-1])
        logger.info(f"Successfully verified IAM role: {role}")
    except iam.exceptions.NoSuchEntityException:
        logger.error(f"IAM role not found: {role}")
        raise

    # Point to the tar.gz file
    model_artifacts = f's3://{bucket_name}/{s3_key}'
    print(f'Model artifacts: {model_artifacts}')

    env = {
        "model_path": "/opt/ml/model",
        "max_length": "256",
        "generation_config": '{"max_length":100,"temperature":0.7,"top_p":0.95,"top_k":50,"do_sample":true}'
    }

    try:
        huggingface_model = HuggingFaceModel(
            model_data=model_artifacts,
            role=role,
            transformers_version="4.37.0",  # Explicit version
            pytorch_version="2.1.0",        # Matching your version
            py_version="py310",             # Keep py310
            env=env,
            name="customer-support-gpt"
        )

        logger.info("Starting model deployment...")
        predictor = huggingface_model.deploy(
            initial_instance_count=1,
            instance_type="ml.m5.xlarge",
            wait=True
        )
        logger.info("Model deployed successfully!")

    except Exception as e:
        logger.error(f"Error during model deployment: {str(e)}")
        raise

except Exception as e:
    logger.error(f"Deployment failed: {str(e)}")
    raise

finally:
    # Clean up the local tar file
    if os.path.exists(tar_path):
        os.remove(tar_path)
src/push_to_s3.py
ADDED
@@ -0,0 +1,52 @@
import boto3
from pathlib import Path
import tarfile
import logging
import os

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

def create_model_tar():
    model_path = Path("models/customer_support_gpt")  # Path to your model folder
    tar_path = "model.tar.gz"  # Path for the output tar.gz file

    # Create a tar.gz file containing all files in the model folder
    with tarfile.open(tar_path, "w:gz") as tar:
        for file_path in model_path.glob("*"):
            if file_path.is_file():
                logger.info(f"Adding {file_path} to tar archive")
                tar.add(file_path, arcname=file_path.name)

    return tar_path

def upload_to_s3(tar_path, bucket_name, s3_key):
    # Initialize S3 client
    s3 = boto3.client("s3")

    # Upload tar.gz file to S3
    logger.info(f"Uploading {tar_path} to s3://{bucket_name}/{s3_key}")
    s3.upload_file(tar_path, bucket_name, s3_key)
    logger.info("Upload complete!")

# Main code
try:
    bucket_name = 'customer-support-gpt'  # Your S3 bucket name
    s3_key = "models/model.tar.gz"  # S3 key (path in bucket)

    # Create the tar.gz archive
    tar_path = create_model_tar()

    # Upload the tar.gz to S3
    upload_to_s3(tar_path, bucket_name, s3_key)

except Exception as e:
    logger.error(f"An error occurred: {str(e)}")
    raise

finally:
    # Clean up the local tar file
    if os.path.exists(tar_path):
        os.remove(tar_path)
        logger.info(f"Deleted local file: {tar_path}")