From 0d615219e58db06831abb81974ab3c37147f7c60 Mon Sep 17 00:00:00 2001 From: BIRK Renaud <renaud.birk@imt-atlantique.net> Date: Tue, 18 Feb 2025 19:58:00 +0100 Subject: [PATCH] feat: API GET /suggest --- src/controllers/suggest_controller.py | 26 ++++++++++++++++ src/main.py | 2 +- src/routes/glb_route.py | 23 -------------- src/routes/router.py | 43 +++++++++++++++++++++++++++ 4 files changed, 70 insertions(+), 24 deletions(-) create mode 100644 src/controllers/suggest_controller.py delete mode 100644 src/routes/glb_route.py create mode 100644 src/routes/router.py diff --git a/src/controllers/suggest_controller.py b/src/controllers/suggest_controller.py new file mode 100644 index 0000000..202da39 --- /dev/null +++ b/src/controllers/suggest_controller.py @@ -0,0 +1,26 @@ +import os +from gradio_client import Client, handle_file + +class SuggestController: + def get_suggest(self, hf_token='', message='', system_prompt='', temperature=.0, + max_new_tokens=0, top_k_sampling=0, repetition_penalty=.0, + top_p_sampling=.0) -> str: + if hf_token == '': + hf_token = os.getenv("hf_token") + + client = Client("mmcgovern574/Mistral-Small-24B-Instruct-2501", hf_token) + + result = client.predict( + message=message, + param_2=system_prompt, + param_3=temperature, + param_4=max_new_tokens, + param_5=top_k_sampling, + param_6=repetition_penalty, + param_7=top_p_sampling, + api_name="/chat" + ) + + print(result) + + return result diff --git a/src/main.py b/src/main.py index e4ce297..a63b876 100644 --- a/src/main.py +++ b/src/main.py @@ -1,6 +1,6 @@ from dotenv import load_dotenv from flask import Flask -from routes.glb_route import setup_routes +from routes.router import setup_routes load_dotenv() diff --git a/src/routes/glb_route.py b/src/routes/glb_route.py deleted file mode 100644 index 8c06398..0000000 --- a/src/routes/glb_route.py +++ /dev/null @@ -1,23 +0,0 @@ -from flask import Blueprint, request -from controllers.glb_controller import GlbController - -glb_bp = Blueprint('glb', __name__) -controller = GlbController() - -def setup_routes(app): - @app.route('/glb', methods=['GET']) - def get_glb(): - base_image_url = request.args.get('base_image_url', type=str) - hf_token = request.args.get('hf_token', '', type=str) - ss_guidance_strength = request.args.get('ss_guidance_strength', 7.5, type=float) - ss_sampling_steps = request.args.get('ss_sampling_steps', 12, type=int) - slat_guidance_strength = request.args.get('slat_guidance_strength', 3, type=float) - slat_sampling_steps = request.args.get('slat_sampling_steps', 12, type=int) - multiimage_algo = request.args.get('multiimage_algo', 'stochastic', type=str) - mesh_simplify = request.args.get('mesh_simplify', 0.95, type=bool) - texture_size = request.args.get('texture_size', 1024, type=int) - - return controller.get_glb(hf_token, base_image_url, ss_guidance_strength, - ss_sampling_steps, slat_guidance_strength, - slat_sampling_steps, multiimage_algo, - mesh_simplify, texture_size) diff --git a/src/routes/router.py b/src/routes/router.py new file mode 100644 index 0000000..ea9e058 --- /dev/null +++ b/src/routes/router.py @@ -0,0 +1,43 @@ +from flask import Blueprint, request +from controllers.glb_controller import GlbController +from controllers.suggest_controller import SuggestController + +glb_bp = Blueprint('glb', __name__) +suggest_bp = Blueprint('suggest', __name__) + +glb_controller = GlbController() +suggest_controller = SuggestController() + +def setup_routes(app): + @app.route('/glb', methods=['GET']) + def get_glb(): + base_image_url = request.args.get('base_image_url', type=str) + hf_token = request.args.get('hf_token', '', type=str) + ss_guidance_strength = request.args.get('ss_guidance_strength', 7.5, type=float) + ss_sampling_steps = request.args.get('ss_sampling_steps', 12, type=int) + slat_guidance_strength = request.args.get('slat_guidance_strength', 3, type=float) + slat_sampling_steps = request.args.get('slat_sampling_steps', 12, type=int) + multiimage_algo = request.args.get('multiimage_algo', 'stochastic', type=str) + mesh_simplify = request.args.get('mesh_simplify', 0.95, type=bool) + texture_size = request.args.get('texture_size', 1024, type=int) + + return glb_controller.get_glb(hf_token, base_image_url, ss_guidance_strength, + ss_sampling_steps, slat_guidance_strength, + slat_sampling_steps, multiimage_algo, + mesh_simplify, texture_size) + + @app.route('/suggest', methods=['GET']) + def get_suggest(): + hf_token = request.args.get('hf_token', '', type=str) + message = request.args.get('message', 'Hello!!', type=str) + system_prompt = request.args.get('system_prompt', 'You are a helpful assistant. First recognize the user request and then reply carefully with thinking.', type=str) + temperature = request.args.get('temperature', 0.7, type=float) + max_new_tokens = request.args.get('max_new_tokens', 12000, type=int) + top_k_sampling = request.args.get('top_k_sampling', 50, type=int) + repetition_penalty = request.args.get('repetition_penalty', 1.1, type=float) + top_p_sampling = request.args.get('top_p_sampling', 0.95, type=float) + + return suggest_controller.get_suggest(hf_token, message, system_prompt, + temperature, max_new_tokens, + top_k_sampling, repetition_penalty, + top_p_sampling) -- GitLab