From 0d615219e58db06831abb81974ab3c37147f7c60 Mon Sep 17 00:00:00 2001
From: BIRK Renaud <renaud.birk@imt-atlantique.net>
Date: Tue, 18 Feb 2025 19:58:00 +0100
Subject: [PATCH] feat: add GET /suggest API endpoint

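Add a SuggestController that forwards chat requests to the
mmcgovern574/Mistral-Small-24B-Instruct-2501 Gradio Space and expose it
through a new GET /suggest route. The existing GLB route moves from
routes/glb_route.py into a consolidated routes/router.py, and main.py now
imports setup_routes from the new module.

A rough usage sketch, assuming the app runs on Flask's default port 5000
and the requests package is available (parameter values below are
illustrative, not the route's defaults):

    import requests

    # Ask the /suggest endpoint for a completion with a few sampling options.
    resp = requests.get("http://localhost:5000/suggest", params={
        "message": "Suggest a short description for a low-poly chair model",
        "temperature": 0.7,
        "max_new_tokens": 512,
    })
    print(resp.text)  # raw text returned by the model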
---
 src/controllers/suggest_controller.py | 26 ++++++++++++++++
 src/main.py                           |  2 +-
 src/routes/glb_route.py               | 23 --------------
 src/routes/router.py                  | 43 +++++++++++++++++++++++++++
 4 files changed, 70 insertions(+), 24 deletions(-)
 create mode 100644 src/controllers/suggest_controller.py
 delete mode 100644 src/routes/glb_route.py
 create mode 100644 src/routes/router.py

diff --git a/src/controllers/suggest_controller.py b/src/controllers/suggest_controller.py
new file mode 100644
index 0000000..202da39
--- /dev/null
+++ b/src/controllers/suggest_controller.py
@@ -0,0 +1,26 @@
+import os
+from gradio_client import Client
+
+class SuggestController:
+    def get_suggest(self, hf_token='', message='', system_prompt='', temperature=0.0,
+                    max_new_tokens=0, top_k_sampling=0, repetition_penalty=0.0,
+                    top_p_sampling=0.0) -> str:
+        # Fall back to the token provided via the environment when none is supplied
+        if not hf_token:
+            hf_token = os.getenv("hf_token")
+
+        # Connect to the Hugging Face Space hosting the chat model
+        client = Client("mmcgovern574/Mistral-Small-24B-Instruct-2501", hf_token=hf_token)
+
+        result = client.predict(
+            message=message,
+            param_2=system_prompt,
+            param_3=temperature,
+            param_4=max_new_tokens,
+            param_5=top_k_sampling,
+            param_6=repetition_penalty,
+            param_7=top_p_sampling,
+            api_name="/chat"
+        )
+
+        return result
diff --git a/src/main.py b/src/main.py
index e4ce297..a63b876 100644
--- a/src/main.py
+++ b/src/main.py
@@ -1,6 +1,6 @@
 from dotenv import load_dotenv
 from flask import Flask
-from routes.glb_route import setup_routes
+from routes.router import setup_routes
 
 load_dotenv()
 
diff --git a/src/routes/glb_route.py b/src/routes/glb_route.py
deleted file mode 100644
index 8c06398..0000000
--- a/src/routes/glb_route.py
+++ /dev/null
@@ -1,23 +0,0 @@
-from flask import Blueprint, request
-from controllers.glb_controller import GlbController
-
-glb_bp = Blueprint('glb', __name__)
-controller = GlbController()
-
-def setup_routes(app):
-    @app.route('/glb', methods=['GET'])
-    def get_glb():
-        base_image_url = request.args.get('base_image_url', type=str)
-        hf_token = request.args.get('hf_token', '', type=str)
-        ss_guidance_strength = request.args.get('ss_guidance_strength', 7.5, type=float)
-        ss_sampling_steps = request.args.get('ss_sampling_steps', 12, type=int)
-        slat_guidance_strength = request.args.get('slat_guidance_strength', 3, type=float)
-        slat_sampling_steps = request.args.get('slat_sampling_steps', 12, type=int)
-        multiimage_algo = request.args.get('multiimage_algo', 'stochastic', type=str)
-        mesh_simplify = request.args.get('mesh_simplify', 0.95, type=bool)
-        texture_size = request.args.get('texture_size', 1024, type=int)
-
-        return controller.get_glb(hf_token, base_image_url, ss_guidance_strength,
-                                  ss_sampling_steps, slat_guidance_strength,
-                                  slat_sampling_steps, multiimage_algo,
-                                  mesh_simplify, texture_size)
diff --git a/src/routes/router.py b/src/routes/router.py
new file mode 100644
index 0000000..ea9e058
--- /dev/null
+++ b/src/routes/router.py
@@ -0,0 +1,43 @@
+from flask import Blueprint, request
+from controllers.glb_controller import GlbController
+from controllers.suggest_controller import SuggestController
+
+glb_bp = Blueprint('glb', __name__)
+suggest_bp = Blueprint('suggest', __name__)
+
+glb_controller = GlbController()
+suggest_controller = SuggestController()
+
+def setup_routes(app):
+    @app.route('/glb', methods=['GET'])
+    def get_glb():
+        base_image_url = request.args.get('base_image_url', type=str)
+        hf_token = request.args.get('hf_token', '', type=str)
+        ss_guidance_strength = request.args.get('ss_guidance_strength', 7.5, type=float)
+        ss_sampling_steps = request.args.get('ss_sampling_steps', 12, type=int)
+        slat_guidance_strength = request.args.get('slat_guidance_strength', 3, type=float)
+        slat_sampling_steps = request.args.get('slat_sampling_steps', 12, type=int)
+        multiimage_algo = request.args.get('multiimage_algo', 'stochastic', type=str)
+        mesh_simplify = request.args.get('mesh_simplify', 0.95, type=float)
+        texture_size = request.args.get('texture_size', 1024, type=int)
+
+        return glb_controller.get_glb(hf_token, base_image_url, ss_guidance_strength,
+                                      ss_sampling_steps, slat_guidance_strength,
+                                      slat_sampling_steps, multiimage_algo,
+                                      mesh_simplify, texture_size)
+
+    @app.route('/suggest', methods=['GET'])
+    def get_suggest():
+        hf_token = request.args.get('hf_token', '', type=str)
+        message = request.args.get('message', 'Hello!!', type=str)
+        system_prompt = request.args.get('system_prompt', 'You are a helpful assistant. First understand the user request, then think it through and reply carefully.', type=str)
+        temperature = request.args.get('temperature', 0.7, type=float)
+        max_new_tokens = request.args.get('max_new_tokens', 12000, type=int)
+        top_k_sampling = request.args.get('top_k_sampling', 50, type=int)
+        repetition_penalty = request.args.get('repetition_penalty', 1.1, type=float)
+        top_p_sampling = request.args.get('top_p_sampling', 0.95, type=float)
+
+        return suggest_controller.get_suggest(hf_token, message, system_prompt,
+                                              temperature, max_new_tokens,
+                                              top_k_sampling, repetition_penalty,
+                                              top_p_sampling)
-- 
GitLab