Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 46 additions & 7 deletions run_inference_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,55 @@
import platform
import argparse
import subprocess
import re
import ipaddress

def run_command(command, shell=False):
"""Run a system command and ensure it succeeds."""
def validate_path(path):
    """Normalize a user-supplied path and reject shell metacharacters.

    The path is collapsed with ``os.path.normpath`` (duplicate separators
    and ``.``/``..`` segments are folded) and the result is scanned for
    characters that carry special meaning in a shell.

    Normalization does NOT confine the path to any directory: a leading
    ``..`` survives ``normpath``, so this is not a traversal guard.

    Args:
        path: Path string supplied by the user.

    Returns:
        The normalized path.

    Raises:
        ValueError: If the path is empty, or contains a NUL byte or any
            of the rejected characters.
    """
    # normpath("") returns ".", which would silently turn a missing
    # path into the current directory, so reject it up front.
    if not path:
        raise ValueError("Path must not be empty")

    normalized = os.path.normpath(path)

    # NUL is included because the OS cannot represent it in an argument
    # and subprocess would otherwise fail later with an unhandled error.
    forbidden = frozenset(";&|$`\n\r><()\0")
    if any(char in forbidden for char in normalized):
        raise ValueError(f"Invalid characters detected in path: {path}")
    return normalized

def validate_ip_address(ip):
    """Check that *ip* is a valid IPv4 or IPv6 address.

    Parsing is delegated to ``ipaddress.ip_address``; the parsed object
    is discarded and the caller's original value is handed back.

    Args:
        ip: Address to validate.

    Returns:
        The input value, unchanged.

    Raises:
        ValueError: If ``ipaddress.ip_address`` rejects the value. The
            parser's own error is attached as the explicit cause.
    """
    try:
        ipaddress.ip_address(ip)
    except ValueError as err:
        # Chain explicitly so the parser's reason is kept as __cause__.
        raise ValueError(f"Invalid IP address: {ip}") from err
    return ip

def validate_prompt(prompt):
    """Reject prompts containing shell-significant substrings.

    Args:
        prompt: Prompt text supplied by the user.

    Returns:
        The prompt, unchanged, when none of the rejected substrings
        occur in it.

    Raises:
        ValueError: If the prompt contains ``$(``, a backtick, ``|``,
            ``;``, ``&``, a newline, or a carriage return.
    """
    # NOTE(review): this filter also rejects ordinary multi-line prompts
    # and text containing '&' or ';'. Confirm that is intended if the
    # command is only ever executed without a shell.
    rejected = ('$(', '`', '|', ';', '&', '\n', '\r')
    if any(token in prompt for token in rejected):
        # Plain literal: the message has no placeholders to format.
        raise ValueError("Invalid characters detected in prompt")
    return prompt

def run_command(command):
    """Run a system command without a shell and exit on failure.

    Args:
        command: Argument list (program followed by its arguments)
            passed to ``subprocess.run``.

    Returns:
        None when the command completes with a zero exit status.

    Raises:
        SystemExit: With status 1 if the command exits non-zero or
            cannot be started at all.
    """
    try:
        # shell=False: arguments go to the program verbatim and are
        # never interpreted by a shell.
        subprocess.run(command, shell=False, check=True)
    except subprocess.CalledProcessError as e:
        print(f"Error occurred while running command: {e}")
        sys.exit(1)
    except OSError as e:
        # Raised when the executable is missing or not runnable; report
        # it through the same exit path instead of a raw traceback.
        print(f"Failed to start command: {e}")
        sys.exit(1)

def run_server():
# Validate all user inputs before using them
try:
validated_model = validate_path(args.model)
validated_host = validate_ip_address(args.host)
if args.prompt:
validated_prompt = validate_prompt(args.prompt)
except ValueError as e:
print(f"Validation error: {e}")
sys.exit(1)

build_dir = "build"
if platform.system() == "Windows":
server_path = os.path.join(build_dir, "bin", "Release", "llama-server.exe")
Expand All @@ -24,23 +63,23 @@ def run_server():

command = [
f'{server_path}',
'-m', args.model,
'-m', validated_model,
'-c', str(args.ctx_size),
'-t', str(args.threads),
'-n', str(args.n_predict),
'-ngl', '0',
'--temp', str(args.temperature),
'--host', args.host,
'--host', validated_host,
'--port', str(args.port),
'-cb' # Enable continuous batching
]

if args.prompt:
command.extend(['-p', args.prompt])
command.extend(['-p', validated_prompt])

# Note: -cnv flag is removed as it's not supported by the server

print(f"Starting server on {args.host}:{args.port}")
print(f"Starting server on {validated_host}:{args.port}")
run_command(command)

def signal_handler(sig, frame):
Expand Down