l’evaluation viene fatta grazie ad un algoritmo che confronta l’entità trovata da kosmos (con il suo relativo bounding box) con l’immagine labeled

Labeled Image:

in questo file è contenuta, sotto forma di xml, il labeling di un’immagine, a cui è stato associato il bounding box per una televisione.

Image:

questa invece è l’immagine originale.

L’algoritmo

L’algoritmo prende in input queste due cose e ci dice se Kosmos 2 è in grado di generare un bonding box (in questo caso per l’entità “television”) accettabile. Per il codice completo (del main) si rimanda qui

passo 1: estrazione delle entità

in questo passo viene preso il file xml e viene estratta l’entità a cui dobbiamo fare riferimento (le entità in caso siano più di una).

 
# Step 2: Extract entities from the labeled image
 
labeled_entities = extract_entities_from_xml(labeled_image_path)
 
#print(labeled_entities)
def extract_entities_from_xml(xml_file):
	
	entities = []	
	tree = ET.parse(xml_file)	
	root = tree.getroot()	  
	
	object_elements = root.findall(".//object")
	
	for object_element in object_elements:
		entity_name = object_element.find("name").text
		xmin = float(object_element.find("bndbox/xmin").text)
		ymin = float(object_element.find("bndbox/ymin").text)
		xmax = float(object_element.find("bndbox/xmax").text)
		ymax = float(object_element.find("bndbox/ymax").text)
	
	# Normalize coordinates
	
	xmin_normalized = xmin / image_width
	ymin_normalized = ymin / image_height
	xmax_normalized = xmax / image_width
	ymax_normalized = ymax / image_height
	
	entities.append((entity_name, (xmin_normalized, ymin_normalized, xmax_normalized, ymax_normalized)))
	  
 
return entities

il risultato di questa fase è una lista di entità con il relativo bounding box, nel nostro caso:

[('television', (0.4, 0.275, 0.6266666666666667, 0.435))]

passo 2: generare il prompt da passare a kosmos

Kosmos ha bisogno dell’immagine, ma anche di un prompt che gli dica “hey, kosmos, mi faresti il grounding della televisione?” Ovviamente questo dipende da qual’è l’entità di cui dobbiamo fare il grounding: quindi lo generiamo a partire dall’output del passo precedente (che tra le altre cose ci aveva dato il nome dell’entità):

Questo lo step nel main:

# Step 2: Extract entities from the labeled image
 
labeled_entities = extract_entities_from_xml(labeled_image_path)
 
print(labeled_entities)

questa la funzione “extract_entities_from_xml”:

def extract_entities_from_xml(xml_file):
 
	entities = []
	tree = ET.parse(xml_file)
	root = tree.getroot()
	
	object_elements = root.findall(".//object")
	for object_element in object_elements:
		entity_name = object_element.find("name").text
		xmin = float(object_element.find("bndbox/xmin").text)
		ymin = float(object_element.find("bndbox/ymin").text)
		xmax = float(object_element.find("bndbox/xmax").text)
		ymax = float(object_element.find("bndbox/ymax").text)
	# Normalize coordinates
	xmin_normalized = xmin / image_width
	ymin_normalized = ymin / image_height
	xmax_normalized = xmax / image_width
	ymax_normalized = ymax / image_height
	
	entities.append((entity_name, (xmin_normalized, ymin_normalized, xmax_normalized, ymax_normalized)))
 
  
 
return entities

il prompt per dire a Kosmos: “trovami in quest’immagine il bounding box di television” è:

"<grounding><phrase>television</phrase>"

passo 3: Dire a kosmos di generare il bounding box dell’entità

A questo punto kosmos ha tutto quello che gli serve per generare il bounding box di “television”.

questo il codice nel main:

# Step 4: Process image with Kosmos and get entities
 
kosmos_entities = process_image_with_kosmos(prompts[0], image_path)
 
print(kosmos_entities)

questo il codice nella funzione “process_image_with_kosmos”:

def process_image_with_kosmos(prompt, image_path):
	
	image = Image.open(image_path)
	inputs = processor(text=prompt, images=image, return_tensors="pt")
	
	generated_ids = model.generate(
		pixel_values=inputs["pixel_values"],
		input_ids=inputs["input_ids"][:, :-1],
		attention_mask=inputs["attention_mask"][:, :-1],
		img_features=None,
		img_attn_mask=inputs["img_attn_mask"][:, :-1],
		use_cache=True,
		max_new_tokens=64,
	)
	generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0] 
	
	processed_text, kosmos_entities = processor.post_process_generation(generated_text)
	
	print(kosmos_entities)
 
  
 
return kosmos_entities

Da Kosmos possiamo farci avere molte cose in output, ma ciò che serve a noi è il bounding box dell’entità “television”. Kosmos da questa, insieme da altre informazioni nell’output di “kosmos_entities”

[('television', (0, 10), [(0.421875, 0.296875, 0.609375, 0.453125)]), ('sofa', (15, 19), [(0.015625, 0.390625, 0.515625, 0.953125)])]

Come si può vedere Kosmos (completely unsolicited) ci ha dato anche il bounding box di qualcos’altro (in questo caso sofa), questo non è un problema perchè l’importante è che ci dia almeno quello di “television”, ma va preso in considerazione perchè lo fa spesso.

passo4: Confrontiamo l’output di kosmos con il labeled data

ora dobbiamo capire se kosmos ha fatto qualcosa di accettabile oppure no?

cosa consideriamo accettabile?

Accettabile è il caso di considerare quando, kosmos fa correttamente il bounding box dell’entità che gli abbiamo chiesto quindi serve che

  1. I Bounding box corrispondano con un’entità di quelle che ci ha dato kosmos
  2. che l’entità abbia il nome giusto In poche parole: non ci interessa che kosmos abbia individuato correttamente il bounding box se poi quel bounding box mi dice che è una “fatina dei denti” e non il televisore che gli avevo chiesto di trovare.
# Step 5: Compare entities and calculate overlapping index
 
for labeled_entity in labeled_entities:
	for kosmos_entity in kosmos_entities:	
		entity_name, labeled_box = labeled_entity	
		entity_name_kosmos, _, kosmos_boxes = kosmos_entity		
		for kosmos_box in kosmos_boxes:		
			overlap_index = overlapping_index(labeled_box, kosmos_box)		
			if overlap_index > 0.5:			
				if(entity_name!=entity_name_kosmos): print(f"The entities match for {entity_name} and {entity_name_kosmos} from kosmos, but the names are different")				
				else:				
					print(f"The entities match for \"{entity_name}\" and \"{entity_name_kosmos}\" from kosmos with an overlapping index of: {overlap_index}")				
				break

in questo codice prendiamo ogni entità (di quelle assegnate sul file xml, quindi normalmente una sola), poi prendiamo ogni entità che ha restituito kosmos, e ne confrontiamo l’overlap index.

L’overlap index è una misura che calcola quanto si sovrappongono i due bounding box. Generalmente in letteratura, se i due bounding box si sovrappongono per più del 50% allora si considera una corrispondenza.

def overlap_area(box1, box2):
 
"""
 
Calculate the overlapping area between two bounding boxes.
 
The input boxes should be in the format (xmin, ymin, xmax, ymax).
 
"""
	
	x_min = max(box1[0], box2[0])	
	y_min = max(box1[1], box2[1])	
	x_max = min(box1[2], box2[2])	
	y_max = min(box1[3], box2[3])
		
	width = max(0, x_max - x_min)	
	height = max(0, y_max - y_min)
		
	return width * height
 
 
def overlapping_index(box1, box2):
	
"""
 
Calculate the overlapping index (IoU) between two bounding boxes.	
The input boxes should be in the format (xmin, ymin, xmax, ymax).
 
"""
 
	area1 = (box1[2] - box1[0]) * (box1[3] - box1[1])   
	area2 = (box2[2] - box2[0]) * (box2[3] - box2[1])
	
	overlap = overlap_area(box1, box2)
	union = area1 + area2 - overlap
	
	return overlap / union if union > 0 else 0

Output dell’esempio

alla fine abbiamo avuto un match tra i due bounding box:

-----------------------------
The entities match for "television" and "television" from kosmos with an overlapping index of: 0.6529275050225192
 
----------------------------

codice del main completo:

def main():
 
# Step 1: Get user input for image and labeled image
 
#image_path = input("Enter the path to the image file (e.g., IMG_1.png): ")
 
image_path="/content/drive/MyDrive/prova_kosmos/img_3.PNG"
 
#labeled_image_path = input("Enter the path to the labeled image file (e.g., labeled_image.xml): ")
 
labeled_image_path="/content/drive/MyDrive/prova_kosmos/labeled_img_3.xml"
 
  
 
# Step 2: Extract entities from the labeled image
 
labeled_entities = extract_entities_from_xml(labeled_image_path)
 
#print(labeled_entities)
 
  
 
# Step 3: Generate prompts from extracted entities
 
prompts = generate_prompt(labeled_entities)
 
#print(prompts[0])
 
  
 
# Step 4: Process image with Kosmos and get entities
 
kosmos_entities = process_image_with_kosmos(prompts[0], image_path)
 
#print(kosmos_entities)
 
  
 
# Step 5: Compare entities and calculate overlapping index
 
for labeled_entity in labeled_entities:
 
for kosmos_entity in kosmos_entities:
 
entity_name, labeled_box = labeled_entity
 
entity_name_kosmos, _, kosmos_boxes = kosmos_entity
 
for kosmos_box in kosmos_boxes:
 
overlap_index = overlapping_index(labeled_box, kosmos_box)
 
if overlap_index > 0.5:
 
if(entity_name!=entity_name_kosmos): print(f"The entities match for {entity_name} and {entity_name_kosmos} from kosmos, but the names are different")
 
else:
 
print(f"The entities match for \"{entity_name}\" and \"{entity_name_kosmos}\" from kosmos with an overlapping index of: {overlap_index}")
 
break
 
  
 
if __name__ == "__main__":
 
main()